From 28b2545163e0bb777d294a4e7a5bd4c77fcaa50b Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Thu, 22 Aug 2024 16:16:02 +0200 Subject: [PATCH] test for connecting clients and bidirectional communication to agent. Required lots of rework since the GOBChannel appeared to be reading ahead of the data it actually needed. Now using more low-level IO to send the clientId over to the agent instead. --- pkg/comms/agentlistener.go | 4 +- pkg/comms/agentlistener_test.go | 4 +- pkg/comms/agentserver.go | 30 +++++--- pkg/comms/agentserver_test.go | 12 ++++ pkg/comms/events.go | 4 -- pkg/server/admin/admin.go | 7 +- pkg/server/admin/admin_test.go | 71 +++++++++++++++++-- pkg/support/iowrappers/readwriteaddrcloser.go | 28 ++++++++ pkg/testsupport/bitpipe.go | 6 +- pkg/testsupport/utils.go | 15 ++++ 10 files changed, 150 insertions(+), 31 deletions(-) diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go index 95dabb1..6290475 100644 --- a/pkg/comms/agentlistener.go +++ b/pkg/comms/agentlistener.go @@ -36,12 +36,12 @@ func (listener AgentListener) Accept() (net.Conn, error) { return nil, err } - clientInfo, err := ReceiveClientInfo(conn) + clientId, err := ReceiveClientInfo(conn) if err != nil { conn.Close() return nil, err } - return NewLocalAddrHackConn(conn, clientInfo.ClientId), nil + return NewLocalAddrHackConn(conn, clientId), nil } func (listener AgentListener) Close() error { diff --git a/pkg/comms/agentlistener_test.go b/pkg/comms/agentlistener_test.go index 547d78b..abcea06 100644 --- a/pkg/comms/agentlistener_test.go +++ b/pkg/comms/agentlistener_test.go @@ -19,9 +19,7 @@ func (s *AgentServerTestSuite) Test_clientIdPassedAsLocalAddr() { func() any { connection, err := serverChannel.Session.OpenStream() s.Nil(err) - gobChannel := NewGOBChannel(connection) - clientInfo := ClientInfo{ClientId: clientId} - err = SendWithTimeout(gobChannel, clientInfo) + err = SendClientInfo(connection, clientId) s.Nil(err) return nil }) diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 076fb4d..38ada3c 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -1,6 +1,7 @@ package comms import ( + "encoding/binary" "fmt" "github.com/hashicorp/yamux" "io" @@ -226,18 +227,31 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error { // decorates the yamux Session (which is a listener) and uses this connection to exchange some // metadata before the connection is handed back to SSH. -func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { - channel := NewGOBChannel(conn) - return SendWithTimeout(channel, info) +// Cannot use GOB for sending clientinfo since this involves mixing of buffered reads by +// GOB with ather reads. Alternatively, we could wrap the GOB message and encode its length, +// and then read the exacct number of bytes when decodeing. But since the clientInfo is just +// a string, this is easier. +func SendClientInfo(conn io.Writer, info string) error { + err := binary.Write(conn, binary.BigEndian, uint32(len(info))) + if err != nil { + return err + } + _, err = conn.Write([]byte(info)) + return err } -func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { - channel := NewGOBChannel(conn) - clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) +func ReceiveClientInfo(conn io.Reader) (string, error) { + var length uint32 + err := binary.Read(conn, binary.BigEndian, &length) if err != nil { - return ClientInfo{}, err + return "", err } - return clientInfo, nil + bytes := make([]byte, length) + _, err = io.ReadFull(conn, bytes) + if err != nil { + return "", err + } + return string(bytes), nil } // message sent on the initial connection from server to agent to confirm the registration diff --git a/pkg/comms/agentserver_test.go b/pkg/comms/agentserver_test.go index dc22798..2b1a5d7 100644 --- a/pkg/comms/agentserver_test.go +++ b/pkg/comms/agentserver_test.go @@ -1,6 +1,7 @@ package comms import ( + "bytes" "context" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" @@ -57,6 +58,17 @@ func TestAgentServerTestSuite(t *testing.T) { suite.Run(t, &AgentServerTestSuite{}) } +func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() { + for _, clientId := range []string{"abc", "djkdfadfha"} { + buf := &bytes.Buffer{} + err := SendClientInfo(buf, clientId) + s.Nil(err) + clientIdReceived, err := ReceiveClientInfo(buf) + s.Nil(err) + s.Equal(clientId, clientIdReceived) + } +} + func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) { commChannels := testsupport.RunAndWait( &s.Suite, diff --git a/pkg/comms/events.go b/pkg/comms/events.go index 0a12365..ec049a1 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -29,10 +29,6 @@ type EnvironmentInfo struct { Shell string } -type ClientInfo struct { - ClientId string -} - type SessionInfo struct { ClientId string diff --git a/pkg/server/admin/admin.go b/pkg/server/admin/admin.go index 3e552c6..e74ac2d 100644 --- a/pkg/server/admin/admin.go +++ b/pkg/server/admin/admin.go @@ -166,16 +166,11 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper return nil, err } log.Println("Successful websocket connection to agent") - - log.Println("Sending connection information to agent") - client := newClient(publicId, clientConn, agentConn, agent.Info.Guid) // Before using this connection for SSH we use it to send client metadata to the // agent - err = comms.SendClientInfo(agentConn, comms.ClientInfo{ - ClientId: string(client.Info.ClientId), - }) + err = comms.SendClientInfo(agentConn, string(client.Info.ClientId)) if err != nil { return nil, err } diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go index e2f3c38..31942c0 100644 --- a/pkg/server/admin/admin_test.go +++ b/pkg/server/admin/admin_test.go @@ -6,6 +6,7 @@ import ( "fmt" "git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/models" + "git.wamblee.org/converge/pkg/support/iowrappers" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" "io" @@ -66,7 +67,7 @@ func (s *AdminTestSuite) SetupTest() { s.NotNil(rand.Read(s.hostKey)) } -func (suite *AdminTestSuite) TearDownTest() { +func (s *AdminTestSuite) TearDownTest() { } func TestAdminTestSuite(t *testing.T) { @@ -80,6 +81,7 @@ type AddAgentResult struct { type AgentRegisterResult struct { registration comms.AgentRegistration + commChannel comms.CommChannel err error } @@ -95,7 +97,7 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri } }, func() any { - res := s.agentRegistration(assignedPublicId, agentRW) + res := s.agentRegistration(agentRW) if assignedPublicId != "" { s.Nil(res.err) s.True(res.registration.Ok) @@ -148,12 +150,71 @@ func (s *AdminTestSuite) Test_agentDuplicateId() { } res, agentSideResult := s.agentRegisters("abc", "") s.NotNil(res.err) - // verify it is the correct error an dnot an id mismatch. + // verify it is the correct error and not an id mismatch. s.True(strings.Contains(res.err.Error(), "could not allocate a new unique id")) s.False(agentSideResult.registration.Ok) } -func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.ReadWriteCloser) AgentRegisterResult { +func (s *AdminTestSuite) Test_connectClient() { + publicId := "abc" + serverRes, agentRes := s.agentRegisters(publicId, "abc") + s.Nil(serverRes.err) + s.Nil(agentRes.err) + + serverToClientRW, clientToServerRW := s.createPipe() + + data := "connect client test msg" + res := testsupport.RunAndWait( + &s.Suite, + func() any { + // server + clientConn, err := s.admin.AddClient(models.RendezVousId(publicId), + iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr"))) + s.Nil(err) + // Connection to agent over yamux + serverToAgentYamux := clientConn.agentConnection + // test by sending a message to the agent. + testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux) + return clientConn + }, + func() any { + // agent + listener := comms.NewAgentListener(agentRes.commChannel.Session) + //.Connection from server over yamux + agentToServerYamux, err := listener.Accept() + s.Nil(err) + // Test by receiving a message from the server + testsupport.AssertReadData(&s.Suite, data, agentToServerYamux) + return agentToServerYamux + }) + + // Now we need to verify bi-directional communication between client and agent through the wserv + + clientConn := res[0].(*clientConnection) + go clientConn.Synchronize() + + agentToServerYamux := res[1].(io.ReadWriter) + + data1 := "mytestdata" + data2 := "mytestdata-2" + testsupport.RunAndWait( + &s.Suite, + func() any { + testsupport.AssertWriteData(&s.Suite, data1, clientToServerRW) + testsupport.AssertReadData(&s.Suite, data2, agentToServerYamux) + return nil + }, + func() any { + testsupport.AssertReadData(&s.Suite, data1, agentToServerYamux) + testsupport.AssertWriteData(&s.Suite, data2, clientToServerRW) + return nil + }) + + // will close the connections and as a result also th synchronize goroutine. + s.cancelFunc() +} + +func (s *AdminTestSuite) agentRegistration(agentRW io.ReadWriteCloser) AgentRegisterResult { // verify registration message received agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW) if err != nil { @@ -164,7 +225,7 @@ func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.R return AgentRegisterResult{registration: agentRegistration, err: err} } s.NotNil(commChannel) - return AgentRegisterResult{registration: agentRegistration, err: nil} + return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil} } func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) { diff --git a/pkg/support/iowrappers/readwriteaddrcloser.go b/pkg/support/iowrappers/readwriteaddrcloser.go index a215521..b46df5d 100644 --- a/pkg/support/iowrappers/readwriteaddrcloser.go +++ b/pkg/support/iowrappers/readwriteaddrcloser.go @@ -10,3 +10,31 @@ type ReadWriteAddrCloser interface { RemoteAddr() net.Addr } + +type SimpleReadWriteAddrCloser struct { + rw io.ReadWriteCloser + addr net.Addr +} + +func NewSimpleReadWriteAddrCloser(rw io.ReadWriteCloser, addr net.Addr) *SimpleReadWriteAddrCloser { + return &SimpleReadWriteAddrCloser{ + rw: rw, + addr: addr, + } +} + +func (s *SimpleReadWriteAddrCloser) Read(p []byte) (int, error) { + return s.rw.Read(p) +} + +func (s *SimpleReadWriteAddrCloser) Write(p []byte) (int, error) { + return s.rw.Write(p) +} + +func (s *SimpleReadWriteAddrCloser) Close() error { + return s.rw.Close() +} + +func (s *SimpleReadWriteAddrCloser) RemoteAddr() net.Addr { + return s.addr +} diff --git a/pkg/testsupport/bitpipe.go b/pkg/testsupport/bitpipe.go index 0fd85f7..a2ebf0c 100644 --- a/pkg/testsupport/bitpipe.go +++ b/pkg/testsupport/bitpipe.go @@ -2,12 +2,12 @@ package testsupport import "git.wamblee.org/converge/pkg/support/iowrappers" -type dummyRemoteAddr string +type DummyRemoteAddr string -func (r dummyRemoteAddr) Network() string { +func (r DummyRemoteAddr) Network() string { return string(r) } -func (r dummyRemoteAddr) String() string { +func (r DummyRemoteAddr) String() string { return string(r) } diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index ca0c5f0..2a8e439 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -4,6 +4,7 @@ import ( "context" "git.wamblee.org/converge/pkg/support/pprof" "github.com/stretchr/testify/suite" + "io" "log" "net/http" "os" @@ -73,3 +74,17 @@ func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Cont } return ctx, compositeCancelFunc } + +func AssertWriteData(s *suite.Suite, data string, writer io.Writer) { + n, err := writer.Write([]byte(data)) + s.Nil(err) + s.Equal(len(data), n) +} + +func AssertReadData(s *suite.Suite, data string, reader io.Reader) { + buf := make([]byte, len(data)*2) + n, err := reader.Read(buf) + s.Nil(err) + s.Equal(len(data), n) + s.Equal(data, string(buf[:n])) +}