test for connecting clients and bidirectional communication to agent.

Required lots of rework since the GOBChannel appeared to be reading
ahead of the data it actually needed. Now using more low-level IO
to send the clientId over to the agent instead.
This commit is contained in:
Erik Brakkee 2024-08-22 16:16:02 +02:00
parent d144a4d230
commit 28b2545163
10 changed files with 150 additions and 31 deletions

View File

@ -36,12 +36,12 @@ func (listener AgentListener) Accept() (net.Conn, error) {
return nil, err return nil, err
} }
clientInfo, err := ReceiveClientInfo(conn) clientId, err := ReceiveClientInfo(conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
return NewLocalAddrHackConn(conn, clientInfo.ClientId), nil return NewLocalAddrHackConn(conn, clientId), nil
} }
func (listener AgentListener) Close() error { func (listener AgentListener) Close() error {

View File

@ -19,9 +19,7 @@ func (s *AgentServerTestSuite) Test_clientIdPassedAsLocalAddr() {
func() any { func() any {
connection, err := serverChannel.Session.OpenStream() connection, err := serverChannel.Session.OpenStream()
s.Nil(err) s.Nil(err)
gobChannel := NewGOBChannel(connection) err = SendClientInfo(connection, clientId)
clientInfo := ClientInfo{ClientId: clientId}
err = SendWithTimeout(gobChannel, clientInfo)
s.Nil(err) s.Nil(err)
return nil return nil
}) })

View File

@ -1,6 +1,7 @@
package comms package comms
import ( import (
"encoding/binary"
"fmt" "fmt"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"io" "io"
@ -226,18 +227,31 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error {
// decorates the yamux Session (which is a listener) and uses this connection to exchange some // decorates the yamux Session (which is a listener) and uses this connection to exchange some
// metadata before the connection is handed back to SSH. // metadata before the connection is handed back to SSH.
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { // Cannot use GOB for sending clientinfo since this involves mixing of buffered reads by
channel := NewGOBChannel(conn) // GOB with ather reads. Alternatively, we could wrap the GOB message and encode its length,
return SendWithTimeout(channel, info) // and then read the exacct number of bytes when decodeing. But since the clientInfo is just
// a string, this is easier.
func SendClientInfo(conn io.Writer, info string) error {
err := binary.Write(conn, binary.BigEndian, uint32(len(info)))
if err != nil {
return err
}
_, err = conn.Write([]byte(info))
return err
} }
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { func ReceiveClientInfo(conn io.Reader) (string, error) {
channel := NewGOBChannel(conn) var length uint32
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) err := binary.Read(conn, binary.BigEndian, &length)
if err != nil { if err != nil {
return ClientInfo{}, err return "", err
} }
return clientInfo, nil bytes := make([]byte, length)
_, err = io.ReadFull(conn, bytes)
if err != nil {
return "", err
}
return string(bytes), nil
} }
// message sent on the initial connection from server to agent to confirm the registration // message sent on the initial connection from server to agent to confirm the registration

View File

@ -1,6 +1,7 @@
package comms package comms
import ( import (
"bytes"
"context" "context"
"git.wamblee.org/converge/pkg/testsupport" "git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -57,6 +58,17 @@ func TestAgentServerTestSuite(t *testing.T) {
suite.Run(t, &AgentServerTestSuite{}) suite.Run(t, &AgentServerTestSuite{})
} }
func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() {
for _, clientId := range []string{"abc", "djkdfadfha"} {
buf := &bytes.Buffer{}
err := SendClientInfo(buf, clientId)
s.Nil(err)
clientIdReceived, err := ReceiveClientInfo(buf)
s.Nil(err)
s.Equal(clientId, clientIdReceived)
}
}
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) { func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
commChannels := testsupport.RunAndWait( commChannels := testsupport.RunAndWait(
&s.Suite, &s.Suite,

View File

@ -29,10 +29,6 @@ type EnvironmentInfo struct {
Shell string Shell string
} }
type ClientInfo struct {
ClientId string
}
type SessionInfo struct { type SessionInfo struct {
ClientId string ClientId string

View File

@ -166,16 +166,11 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper
return nil, err return nil, err
} }
log.Println("Successful websocket connection to agent") log.Println("Successful websocket connection to agent")
log.Println("Sending connection information to agent")
client := newClient(publicId, clientConn, agentConn, agent.Info.Guid) client := newClient(publicId, clientConn, agentConn, agent.Info.Guid)
// Before using this connection for SSH we use it to send client metadata to the // Before using this connection for SSH we use it to send client metadata to the
// agent // agent
err = comms.SendClientInfo(agentConn, comms.ClientInfo{ err = comms.SendClientInfo(agentConn, string(client.Info.ClientId))
ClientId: string(client.Info.ClientId),
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/models" "git.wamblee.org/converge/pkg/models"
"git.wamblee.org/converge/pkg/support/iowrappers"
"git.wamblee.org/converge/pkg/testsupport" "git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"io" "io"
@ -66,7 +67,7 @@ func (s *AdminTestSuite) SetupTest() {
s.NotNil(rand.Read(s.hostKey)) s.NotNil(rand.Read(s.hostKey))
} }
func (suite *AdminTestSuite) TearDownTest() { func (s *AdminTestSuite) TearDownTest() {
} }
func TestAdminTestSuite(t *testing.T) { func TestAdminTestSuite(t *testing.T) {
@ -80,6 +81,7 @@ type AddAgentResult struct {
type AgentRegisterResult struct { type AgentRegisterResult struct {
registration comms.AgentRegistration registration comms.AgentRegistration
commChannel comms.CommChannel
err error err error
} }
@ -95,7 +97,7 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri
} }
}, },
func() any { func() any {
res := s.agentRegistration(assignedPublicId, agentRW) res := s.agentRegistration(agentRW)
if assignedPublicId != "" { if assignedPublicId != "" {
s.Nil(res.err) s.Nil(res.err)
s.True(res.registration.Ok) s.True(res.registration.Ok)
@ -153,7 +155,66 @@ func (s *AdminTestSuite) Test_agentDuplicateId() {
s.False(agentSideResult.registration.Ok) s.False(agentSideResult.registration.Ok)
} }
func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.ReadWriteCloser) AgentRegisterResult { func (s *AdminTestSuite) Test_connectClient() {
publicId := "abc"
serverRes, agentRes := s.agentRegisters(publicId, "abc")
s.Nil(serverRes.err)
s.Nil(agentRes.err)
serverToClientRW, clientToServerRW := s.createPipe()
data := "connect client test msg"
res := testsupport.RunAndWait(
&s.Suite,
func() any {
// server
clientConn, err := s.admin.AddClient(models.RendezVousId(publicId),
iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr")))
s.Nil(err)
// Connection to agent over yamux
serverToAgentYamux := clientConn.agentConnection
// test by sending a message to the agent.
testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux)
return clientConn
},
func() any {
// agent
listener := comms.NewAgentListener(agentRes.commChannel.Session)
//.Connection from server over yamux
agentToServerYamux, err := listener.Accept()
s.Nil(err)
// Test by receiving a message from the server
testsupport.AssertReadData(&s.Suite, data, agentToServerYamux)
return agentToServerYamux
})
// Now we need to verify bi-directional communication between client and agent through the wserv
clientConn := res[0].(*clientConnection)
go clientConn.Synchronize()
agentToServerYamux := res[1].(io.ReadWriter)
data1 := "mytestdata"
data2 := "mytestdata-2"
testsupport.RunAndWait(
&s.Suite,
func() any {
testsupport.AssertWriteData(&s.Suite, data1, clientToServerRW)
testsupport.AssertReadData(&s.Suite, data2, agentToServerYamux)
return nil
},
func() any {
testsupport.AssertReadData(&s.Suite, data1, agentToServerYamux)
testsupport.AssertWriteData(&s.Suite, data2, clientToServerRW)
return nil
})
// will close the connections and as a result also th synchronize goroutine.
s.cancelFunc()
}
func (s *AdminTestSuite) agentRegistration(agentRW io.ReadWriteCloser) AgentRegisterResult {
// verify registration message received // verify registration message received
agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW) agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW)
if err != nil { if err != nil {
@ -164,7 +225,7 @@ func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.R
return AgentRegisterResult{registration: agentRegistration, err: err} return AgentRegisterResult{registration: agentRegistration, err: err}
} }
s.NotNil(commChannel) s.NotNil(commChannel)
return AgentRegisterResult{registration: agentRegistration, err: nil} return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil}
} }
func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) { func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) {

View File

@ -10,3 +10,31 @@ type ReadWriteAddrCloser interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
} }
type SimpleReadWriteAddrCloser struct {
rw io.ReadWriteCloser
addr net.Addr
}
func NewSimpleReadWriteAddrCloser(rw io.ReadWriteCloser, addr net.Addr) *SimpleReadWriteAddrCloser {
return &SimpleReadWriteAddrCloser{
rw: rw,
addr: addr,
}
}
func (s *SimpleReadWriteAddrCloser) Read(p []byte) (int, error) {
return s.rw.Read(p)
}
func (s *SimpleReadWriteAddrCloser) Write(p []byte) (int, error) {
return s.rw.Write(p)
}
func (s *SimpleReadWriteAddrCloser) Close() error {
return s.rw.Close()
}
func (s *SimpleReadWriteAddrCloser) RemoteAddr() net.Addr {
return s.addr
}

View File

@ -2,12 +2,12 @@ package testsupport
import "git.wamblee.org/converge/pkg/support/iowrappers" import "git.wamblee.org/converge/pkg/support/iowrappers"
type dummyRemoteAddr string type DummyRemoteAddr string
func (r dummyRemoteAddr) Network() string { func (r DummyRemoteAddr) Network() string {
return string(r) return string(r)
} }
func (r dummyRemoteAddr) String() string { func (r DummyRemoteAddr) String() string {
return string(r) return string(r)
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"git.wamblee.org/converge/pkg/support/pprof" "git.wamblee.org/converge/pkg/support/pprof"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"io"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -73,3 +74,17 @@ func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Cont
} }
return ctx, compositeCancelFunc return ctx, compositeCancelFunc
} }
func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
n, err := writer.Write([]byte(data))
s.Nil(err)
s.Equal(len(data), n)
}
func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
buf := make([]byte, len(data)*2)
n, err := reader.Read(buf)
s.Nil(err)
s.Equal(len(data), n)
s.Equal(data, string(buf[:n]))
}