From 7e062f577773874882bc16540b9ab88816a9e7b5 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Sat, 24 Aug 2024 20:34:27 +0200 Subject: [PATCH] reintroduced ClientInfo because it does appear to work. Most likely some error elsewhere caused it not to work previously --- pkg/comms/agentlistener.go | 2 +- pkg/comms/agentlistener_test.go | 2 +- pkg/comms/agentserver.go | 30 ++++++++---------------------- pkg/comms/agentserver_test.go | 4 ++-- pkg/comms/events.go | 4 ++++ pkg/server/admin/admin.go | 2 +- 6 files changed, 17 insertions(+), 27 deletions(-) diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go index 71f7db8..7691c5e 100644 --- a/pkg/comms/agentlistener.go +++ b/pkg/comms/agentlistener.go @@ -41,7 +41,7 @@ func (listener AgentListener) Accept() (net.Conn, error) { conn.Close() return nil, err } - conn = NewLocalAddrHackConn(conn, clientId) + conn = NewLocalAddrHackConn(conn, clientId.ClientId) return conn, nil } diff --git a/pkg/comms/agentlistener_test.go b/pkg/comms/agentlistener_test.go index abcea06..762ea43 100644 --- a/pkg/comms/agentlistener_test.go +++ b/pkg/comms/agentlistener_test.go @@ -19,7 +19,7 @@ func (s *AgentServerTestSuite) Test_clientIdPassedAsLocalAddr() { func() any { connection, err := serverChannel.Session.OpenStream() s.Nil(err) - err = SendClientInfo(connection, clientId) + err = SendClientInfo(connection, ClientInfo{ClientId: clientId}) s.Nil(err) return nil }) diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 38ada3c..076fb4d 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -1,7 +1,6 @@ package comms import ( - "encoding/binary" "fmt" "github.com/hashicorp/yamux" "io" @@ -227,31 +226,18 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error { // decorates the yamux Session (which is a listener) and uses this connection to exchange some // metadata before the connection is handed back to SSH. -// Cannot use GOB for sending clientinfo since this involves mixing of buffered reads by -// GOB with ather reads. Alternatively, we could wrap the GOB message and encode its length, -// and then read the exacct number of bytes when decodeing. But since the clientInfo is just -// a string, this is easier. -func SendClientInfo(conn io.Writer, info string) error { - err := binary.Write(conn, binary.BigEndian, uint32(len(info))) - if err != nil { - return err - } - _, err = conn.Write([]byte(info)) - return err +func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { + channel := NewGOBChannel(conn) + return SendWithTimeout(channel, info) } -func ReceiveClientInfo(conn io.Reader) (string, error) { - var length uint32 - err := binary.Read(conn, binary.BigEndian, &length) +func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { + channel := NewGOBChannel(conn) + clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) if err != nil { - return "", err + return ClientInfo{}, err } - bytes := make([]byte, length) - _, err = io.ReadFull(conn, bytes) - if err != nil { - return "", err - } - return string(bytes), nil + return clientInfo, nil } // message sent on the initial connection from server to agent to confirm the registration diff --git a/pkg/comms/agentserver_test.go b/pkg/comms/agentserver_test.go index a08f0e2..48a98a4 100644 --- a/pkg/comms/agentserver_test.go +++ b/pkg/comms/agentserver_test.go @@ -64,11 +64,11 @@ func TestAgentServerTestSuite(t *testing.T) { func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() { for _, clientId := range []string{"abc", "djkdfadfha"} { buf := &bytes.Buffer{} - err := SendClientInfo(buf, clientId) + err := SendClientInfo(buf, ClientInfo{ClientId: clientId}) s.Nil(err) clientIdReceived, err := ReceiveClientInfo(buf) s.Nil(err) - s.Equal(clientId, clientIdReceived) + s.Equal(clientId, clientIdReceived.ClientId) } } diff --git a/pkg/comms/events.go b/pkg/comms/events.go index ec049a1..0a12365 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -29,6 +29,10 @@ type EnvironmentInfo struct { Shell string } +type ClientInfo struct { + ClientId string +} + type SessionInfo struct { ClientId string diff --git a/pkg/server/admin/admin.go b/pkg/server/admin/admin.go index 1a3f0c0..542c172 100644 --- a/pkg/server/admin/admin.go +++ b/pkg/server/admin/admin.go @@ -180,7 +180,7 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper // Before using this connection for SSH we use it to send client metadata to the // agent - err = comms.SendClientInfo(agentConn, string(client.Info.ClientId)) + err = comms.SendClientInfo(agentConn, comms.ClientInfo{ClientId: string(client.Info.ClientId)}) if err != nil { return nil, err }