diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go index 95dabb1..6290475 100644 --- a/pkg/comms/agentlistener.go +++ b/pkg/comms/agentlistener.go @@ -36,12 +36,12 @@ func (listener AgentListener) Accept() (net.Conn, error) { return nil, err } - clientInfo, err := ReceiveClientInfo(conn) + clientId, err := ReceiveClientInfo(conn) if err != nil { conn.Close() return nil, err } - return NewLocalAddrHackConn(conn, clientInfo.ClientId), nil + return NewLocalAddrHackConn(conn, clientId), nil } func (listener AgentListener) Close() error { diff --git a/pkg/comms/agentlistener_test.go b/pkg/comms/agentlistener_test.go index 547d78b..abcea06 100644 --- a/pkg/comms/agentlistener_test.go +++ b/pkg/comms/agentlistener_test.go @@ -19,9 +19,7 @@ func (s *AgentServerTestSuite) Test_clientIdPassedAsLocalAddr() { func() any { connection, err := serverChannel.Session.OpenStream() s.Nil(err) - gobChannel := NewGOBChannel(connection) - clientInfo := ClientInfo{ClientId: clientId} - err = SendWithTimeout(gobChannel, clientInfo) + err = SendClientInfo(connection, clientId) s.Nil(err) return nil }) diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 076fb4d..38ada3c 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -1,6 +1,7 @@ package comms import ( + "encoding/binary" "fmt" "github.com/hashicorp/yamux" "io" @@ -226,18 +227,31 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error { // decorates the yamux Session (which is a listener) and uses this connection to exchange some // metadata before the connection is handed back to SSH. -func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { - channel := NewGOBChannel(conn) - return SendWithTimeout(channel, info) +// Cannot use GOB for sending clientinfo since this involves mixing of buffered reads by +// GOB with ather reads. Alternatively, we could wrap the GOB message and encode its length, +// and then read the exacct number of bytes when decodeing. But since the clientInfo is just +// a string, this is easier. +func SendClientInfo(conn io.Writer, info string) error { + err := binary.Write(conn, binary.BigEndian, uint32(len(info))) + if err != nil { + return err + } + _, err = conn.Write([]byte(info)) + return err } -func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { - channel := NewGOBChannel(conn) - clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) +func ReceiveClientInfo(conn io.Reader) (string, error) { + var length uint32 + err := binary.Read(conn, binary.BigEndian, &length) if err != nil { - return ClientInfo{}, err + return "", err } - return clientInfo, nil + bytes := make([]byte, length) + _, err = io.ReadFull(conn, bytes) + if err != nil { + return "", err + } + return string(bytes), nil } // message sent on the initial connection from server to agent to confirm the registration diff --git a/pkg/comms/agentserver_test.go b/pkg/comms/agentserver_test.go index dc22798..2b1a5d7 100644 --- a/pkg/comms/agentserver_test.go +++ b/pkg/comms/agentserver_test.go @@ -1,6 +1,7 @@ package comms import ( + "bytes" "context" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" @@ -57,6 +58,17 @@ func TestAgentServerTestSuite(t *testing.T) { suite.Run(t, &AgentServerTestSuite{}) } +func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() { + for _, clientId := range []string{"abc", "djkdfadfha"} { + buf := &bytes.Buffer{} + err := SendClientInfo(buf, clientId) + s.Nil(err) + clientIdReceived, err := ReceiveClientInfo(buf) + s.Nil(err) + s.Equal(clientId, clientIdReceived) + } +} + func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) { commChannels := testsupport.RunAndWait( &s.Suite, diff --git a/pkg/comms/events.go b/pkg/comms/events.go index 0a12365..ec049a1 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -29,10 +29,6 @@ type EnvironmentInfo struct { Shell string } -type ClientInfo struct { - ClientId string -} - type SessionInfo struct { ClientId string diff --git a/pkg/server/admin/admin.go b/pkg/server/admin/admin.go index 3e552c6..e74ac2d 100644 --- a/pkg/server/admin/admin.go +++ b/pkg/server/admin/admin.go @@ -166,16 +166,11 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper return nil, err } log.Println("Successful websocket connection to agent") - - log.Println("Sending connection information to agent") - client := newClient(publicId, clientConn, agentConn, agent.Info.Guid) // Before using this connection for SSH we use it to send client metadata to the // agent - err = comms.SendClientInfo(agentConn, comms.ClientInfo{ - ClientId: string(client.Info.ClientId), - }) + err = comms.SendClientInfo(agentConn, string(client.Info.ClientId)) if err != nil { return nil, err } diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go index e2f3c38..31942c0 100644 --- a/pkg/server/admin/admin_test.go +++ b/pkg/server/admin/admin_test.go @@ -6,6 +6,7 @@ import ( "fmt" "git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/models" + "git.wamblee.org/converge/pkg/support/iowrappers" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" "io" @@ -66,7 +67,7 @@ func (s *AdminTestSuite) SetupTest() { s.NotNil(rand.Read(s.hostKey)) } -func (suite *AdminTestSuite) TearDownTest() { +func (s *AdminTestSuite) TearDownTest() { } func TestAdminTestSuite(t *testing.T) { @@ -80,6 +81,7 @@ type AddAgentResult struct { type AgentRegisterResult struct { registration comms.AgentRegistration + commChannel comms.CommChannel err error } @@ -95,7 +97,7 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri } }, func() any { - res := s.agentRegistration(assignedPublicId, agentRW) + res := s.agentRegistration(agentRW) if assignedPublicId != "" { s.Nil(res.err) s.True(res.registration.Ok) @@ -148,12 +150,71 @@ func (s *AdminTestSuite) Test_agentDuplicateId() { } res, agentSideResult := s.agentRegisters("abc", "") s.NotNil(res.err) - // verify it is the correct error an dnot an id mismatch. + // verify it is the correct error and not an id mismatch. s.True(strings.Contains(res.err.Error(), "could not allocate a new unique id")) s.False(agentSideResult.registration.Ok) } -func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.ReadWriteCloser) AgentRegisterResult { +func (s *AdminTestSuite) Test_connectClient() { + publicId := "abc" + serverRes, agentRes := s.agentRegisters(publicId, "abc") + s.Nil(serverRes.err) + s.Nil(agentRes.err) + + serverToClientRW, clientToServerRW := s.createPipe() + + data := "connect client test msg" + res := testsupport.RunAndWait( + &s.Suite, + func() any { + // server + clientConn, err := s.admin.AddClient(models.RendezVousId(publicId), + iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr"))) + s.Nil(err) + // Connection to agent over yamux + serverToAgentYamux := clientConn.agentConnection + // test by sending a message to the agent. + testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux) + return clientConn + }, + func() any { + // agent + listener := comms.NewAgentListener(agentRes.commChannel.Session) + //.Connection from server over yamux + agentToServerYamux, err := listener.Accept() + s.Nil(err) + // Test by receiving a message from the server + testsupport.AssertReadData(&s.Suite, data, agentToServerYamux) + return agentToServerYamux + }) + + // Now we need to verify bi-directional communication between client and agent through the wserv + + clientConn := res[0].(*clientConnection) + go clientConn.Synchronize() + + agentToServerYamux := res[1].(io.ReadWriter) + + data1 := "mytestdata" + data2 := "mytestdata-2" + testsupport.RunAndWait( + &s.Suite, + func() any { + testsupport.AssertWriteData(&s.Suite, data1, clientToServerRW) + testsupport.AssertReadData(&s.Suite, data2, agentToServerYamux) + return nil + }, + func() any { + testsupport.AssertReadData(&s.Suite, data1, agentToServerYamux) + testsupport.AssertWriteData(&s.Suite, data2, clientToServerRW) + return nil + }) + + // will close the connections and as a result also th synchronize goroutine. + s.cancelFunc() +} + +func (s *AdminTestSuite) agentRegistration(agentRW io.ReadWriteCloser) AgentRegisterResult { // verify registration message received agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW) if err != nil { @@ -164,7 +225,7 @@ func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.R return AgentRegisterResult{registration: agentRegistration, err: err} } s.NotNil(commChannel) - return AgentRegisterResult{registration: agentRegistration, err: nil} + return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil} } func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) { diff --git a/pkg/support/iowrappers/readwriteaddrcloser.go b/pkg/support/iowrappers/readwriteaddrcloser.go index a215521..b46df5d 100644 --- a/pkg/support/iowrappers/readwriteaddrcloser.go +++ b/pkg/support/iowrappers/readwriteaddrcloser.go @@ -10,3 +10,31 @@ type ReadWriteAddrCloser interface { RemoteAddr() net.Addr } + +type SimpleReadWriteAddrCloser struct { + rw io.ReadWriteCloser + addr net.Addr +} + +func NewSimpleReadWriteAddrCloser(rw io.ReadWriteCloser, addr net.Addr) *SimpleReadWriteAddrCloser { + return &SimpleReadWriteAddrCloser{ + rw: rw, + addr: addr, + } +} + +func (s *SimpleReadWriteAddrCloser) Read(p []byte) (int, error) { + return s.rw.Read(p) +} + +func (s *SimpleReadWriteAddrCloser) Write(p []byte) (int, error) { + return s.rw.Write(p) +} + +func (s *SimpleReadWriteAddrCloser) Close() error { + return s.rw.Close() +} + +func (s *SimpleReadWriteAddrCloser) RemoteAddr() net.Addr { + return s.addr +} diff --git a/pkg/testsupport/bitpipe.go b/pkg/testsupport/bitpipe.go index 0fd85f7..a2ebf0c 100644 --- a/pkg/testsupport/bitpipe.go +++ b/pkg/testsupport/bitpipe.go @@ -2,12 +2,12 @@ package testsupport import "git.wamblee.org/converge/pkg/support/iowrappers" -type dummyRemoteAddr string +type DummyRemoteAddr string -func (r dummyRemoteAddr) Network() string { +func (r DummyRemoteAddr) Network() string { return string(r) } -func (r dummyRemoteAddr) String() string { +func (r DummyRemoteAddr) String() string { return string(r) } diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index ca0c5f0..2a8e439 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -4,6 +4,7 @@ import ( "context" "git.wamblee.org/converge/pkg/support/pprof" "github.com/stretchr/testify/suite" + "io" "log" "net/http" "os" @@ -73,3 +74,17 @@ func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Cont } return ctx, compositeCancelFunc } + +func AssertWriteData(s *suite.Suite, data string, writer io.Writer) { + n, err := writer.Write([]byte(data)) + s.Nil(err) + s.Equal(len(data), n) +} + +func AssertReadData(s *suite.Suite, data string, reader io.Reader) { + buf := make([]byte, len(data)*2) + n, err := reader.Read(buf) + s.Nil(err) + s.Equal(len(data), n) + s.Equal(data, string(buf[:n])) +}