test for connecting clients and bidirectional communication to agent.
Required lots of rework since the GOBChannel appeared to be reading ahead of the data it actually needed. Now using more low-level IO to send the clientId over to the agent instead.
This commit is contained in:
parent
d144a4d230
commit
28b2545163
@ -36,12 +36,12 @@ func (listener AgentListener) Accept() (net.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
clientInfo, err := ReceiveClientInfo(conn)
|
clientId, err := ReceiveClientInfo(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewLocalAddrHackConn(conn, clientInfo.ClientId), nil
|
return NewLocalAddrHackConn(conn, clientId), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (listener AgentListener) Close() error {
|
func (listener AgentListener) Close() error {
|
||||||
|
@ -19,9 +19,7 @@ func (s *AgentServerTestSuite) Test_clientIdPassedAsLocalAddr() {
|
|||||||
func() any {
|
func() any {
|
||||||
connection, err := serverChannel.Session.OpenStream()
|
connection, err := serverChannel.Session.OpenStream()
|
||||||
s.Nil(err)
|
s.Nil(err)
|
||||||
gobChannel := NewGOBChannel(connection)
|
err = SendClientInfo(connection, clientId)
|
||||||
clientInfo := ClientInfo{ClientId: clientId}
|
|
||||||
err = SendWithTimeout(gobChannel, clientInfo)
|
|
||||||
s.Nil(err)
|
s.Nil(err)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package comms
|
package comms
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/hashicorp/yamux"
|
"github.com/hashicorp/yamux"
|
||||||
"io"
|
"io"
|
||||||
@ -226,18 +227,31 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
|||||||
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
||||||
// metadata before the connection is handed back to SSH.
|
// metadata before the connection is handed back to SSH.
|
||||||
|
|
||||||
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
|
// Cannot use GOB for sending clientinfo since this involves mixing of buffered reads by
|
||||||
channel := NewGOBChannel(conn)
|
// GOB with ather reads. Alternatively, we could wrap the GOB message and encode its length,
|
||||||
return SendWithTimeout(channel, info)
|
// and then read the exacct number of bytes when decodeing. But since the clientInfo is just
|
||||||
|
// a string, this is easier.
|
||||||
|
func SendClientInfo(conn io.Writer, info string) error {
|
||||||
|
err := binary.Write(conn, binary.BigEndian, uint32(len(info)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = conn.Write([]byte(info))
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
|
func ReceiveClientInfo(conn io.Reader) (string, error) {
|
||||||
channel := NewGOBChannel(conn)
|
var length uint32
|
||||||
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
|
err := binary.Read(conn, binary.BigEndian, &length)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ClientInfo{}, err
|
return "", err
|
||||||
}
|
}
|
||||||
return clientInfo, nil
|
bytes := make([]byte, length)
|
||||||
|
_, err = io.ReadFull(conn, bytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// message sent on the initial connection from server to agent to confirm the registration
|
// message sent on the initial connection from server to agent to confirm the registration
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package comms
|
package comms
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"git.wamblee.org/converge/pkg/testsupport"
|
"git.wamblee.org/converge/pkg/testsupport"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@ -57,6 +58,17 @@ func TestAgentServerTestSuite(t *testing.T) {
|
|||||||
suite.Run(t, &AgentServerTestSuite{})
|
suite.Run(t, &AgentServerTestSuite{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() {
|
||||||
|
for _, clientId := range []string{"abc", "djkdfadfha"} {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err := SendClientInfo(buf, clientId)
|
||||||
|
s.Nil(err)
|
||||||
|
clientIdReceived, err := ReceiveClientInfo(buf)
|
||||||
|
s.Nil(err)
|
||||||
|
s.Equal(clientId, clientIdReceived)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
|
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
|
||||||
commChannels := testsupport.RunAndWait(
|
commChannels := testsupport.RunAndWait(
|
||||||
&s.Suite,
|
&s.Suite,
|
||||||
|
@ -29,10 +29,6 @@ type EnvironmentInfo struct {
|
|||||||
Shell string
|
Shell string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientInfo struct {
|
|
||||||
ClientId string
|
|
||||||
}
|
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
ClientId string
|
ClientId string
|
||||||
|
|
||||||
|
@ -166,16 +166,11 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Println("Successful websocket connection to agent")
|
log.Println("Successful websocket connection to agent")
|
||||||
|
|
||||||
log.Println("Sending connection information to agent")
|
|
||||||
|
|
||||||
client := newClient(publicId, clientConn, agentConn, agent.Info.Guid)
|
client := newClient(publicId, clientConn, agentConn, agent.Info.Guid)
|
||||||
|
|
||||||
// Before using this connection for SSH we use it to send client metadata to the
|
// Before using this connection for SSH we use it to send client metadata to the
|
||||||
// agent
|
// agent
|
||||||
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
|
err = comms.SendClientInfo(agentConn, string(client.Info.ClientId))
|
||||||
ClientId: string(client.Info.ClientId),
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"git.wamblee.org/converge/pkg/comms"
|
"git.wamblee.org/converge/pkg/comms"
|
||||||
"git.wamblee.org/converge/pkg/models"
|
"git.wamblee.org/converge/pkg/models"
|
||||||
|
"git.wamblee.org/converge/pkg/support/iowrappers"
|
||||||
"git.wamblee.org/converge/pkg/testsupport"
|
"git.wamblee.org/converge/pkg/testsupport"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"io"
|
"io"
|
||||||
@ -66,7 +67,7 @@ func (s *AdminTestSuite) SetupTest() {
|
|||||||
s.NotNil(rand.Read(s.hostKey))
|
s.NotNil(rand.Read(s.hostKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *AdminTestSuite) TearDownTest() {
|
func (s *AdminTestSuite) TearDownTest() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAdminTestSuite(t *testing.T) {
|
func TestAdminTestSuite(t *testing.T) {
|
||||||
@ -80,6 +81,7 @@ type AddAgentResult struct {
|
|||||||
|
|
||||||
type AgentRegisterResult struct {
|
type AgentRegisterResult struct {
|
||||||
registration comms.AgentRegistration
|
registration comms.AgentRegistration
|
||||||
|
commChannel comms.CommChannel
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +97,7 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
func() any {
|
func() any {
|
||||||
res := s.agentRegistration(assignedPublicId, agentRW)
|
res := s.agentRegistration(agentRW)
|
||||||
if assignedPublicId != "" {
|
if assignedPublicId != "" {
|
||||||
s.Nil(res.err)
|
s.Nil(res.err)
|
||||||
s.True(res.registration.Ok)
|
s.True(res.registration.Ok)
|
||||||
@ -153,7 +155,66 @@ func (s *AdminTestSuite) Test_agentDuplicateId() {
|
|||||||
s.False(agentSideResult.registration.Ok)
|
s.False(agentSideResult.registration.Ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.ReadWriteCloser) AgentRegisterResult {
|
func (s *AdminTestSuite) Test_connectClient() {
|
||||||
|
publicId := "abc"
|
||||||
|
serverRes, agentRes := s.agentRegisters(publicId, "abc")
|
||||||
|
s.Nil(serverRes.err)
|
||||||
|
s.Nil(agentRes.err)
|
||||||
|
|
||||||
|
serverToClientRW, clientToServerRW := s.createPipe()
|
||||||
|
|
||||||
|
data := "connect client test msg"
|
||||||
|
res := testsupport.RunAndWait(
|
||||||
|
&s.Suite,
|
||||||
|
func() any {
|
||||||
|
// server
|
||||||
|
clientConn, err := s.admin.AddClient(models.RendezVousId(publicId),
|
||||||
|
iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr")))
|
||||||
|
s.Nil(err)
|
||||||
|
// Connection to agent over yamux
|
||||||
|
serverToAgentYamux := clientConn.agentConnection
|
||||||
|
// test by sending a message to the agent.
|
||||||
|
testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux)
|
||||||
|
return clientConn
|
||||||
|
},
|
||||||
|
func() any {
|
||||||
|
// agent
|
||||||
|
listener := comms.NewAgentListener(agentRes.commChannel.Session)
|
||||||
|
//.Connection from server over yamux
|
||||||
|
agentToServerYamux, err := listener.Accept()
|
||||||
|
s.Nil(err)
|
||||||
|
// Test by receiving a message from the server
|
||||||
|
testsupport.AssertReadData(&s.Suite, data, agentToServerYamux)
|
||||||
|
return agentToServerYamux
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now we need to verify bi-directional communication between client and agent through the wserv
|
||||||
|
|
||||||
|
clientConn := res[0].(*clientConnection)
|
||||||
|
go clientConn.Synchronize()
|
||||||
|
|
||||||
|
agentToServerYamux := res[1].(io.ReadWriter)
|
||||||
|
|
||||||
|
data1 := "mytestdata"
|
||||||
|
data2 := "mytestdata-2"
|
||||||
|
testsupport.RunAndWait(
|
||||||
|
&s.Suite,
|
||||||
|
func() any {
|
||||||
|
testsupport.AssertWriteData(&s.Suite, data1, clientToServerRW)
|
||||||
|
testsupport.AssertReadData(&s.Suite, data2, agentToServerYamux)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func() any {
|
||||||
|
testsupport.AssertReadData(&s.Suite, data1, agentToServerYamux)
|
||||||
|
testsupport.AssertWriteData(&s.Suite, data2, clientToServerRW)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// will close the connections and as a result also th synchronize goroutine.
|
||||||
|
s.cancelFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AdminTestSuite) agentRegistration(agentRW io.ReadWriteCloser) AgentRegisterResult {
|
||||||
// verify registration message received
|
// verify registration message received
|
||||||
agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW)
|
agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -164,7 +225,7 @@ func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.R
|
|||||||
return AgentRegisterResult{registration: agentRegistration, err: err}
|
return AgentRegisterResult{registration: agentRegistration, err: err}
|
||||||
}
|
}
|
||||||
s.NotNil(commChannel)
|
s.NotNil(commChannel)
|
||||||
return AgentRegisterResult{registration: agentRegistration, err: nil}
|
return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) {
|
func (s *AdminTestSuite) addAgent(publicId string, assignedPublicId string, serverRW io.ReadWriteCloser) (*agentConnection, error) {
|
||||||
|
@ -10,3 +10,31 @@ type ReadWriteAddrCloser interface {
|
|||||||
|
|
||||||
RemoteAddr() net.Addr
|
RemoteAddr() net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SimpleReadWriteAddrCloser struct {
|
||||||
|
rw io.ReadWriteCloser
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSimpleReadWriteAddrCloser(rw io.ReadWriteCloser, addr net.Addr) *SimpleReadWriteAddrCloser {
|
||||||
|
return &SimpleReadWriteAddrCloser{
|
||||||
|
rw: rw,
|
||||||
|
addr: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleReadWriteAddrCloser) Read(p []byte) (int, error) {
|
||||||
|
return s.rw.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleReadWriteAddrCloser) Write(p []byte) (int, error) {
|
||||||
|
return s.rw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleReadWriteAddrCloser) Close() error {
|
||||||
|
return s.rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleReadWriteAddrCloser) RemoteAddr() net.Addr {
|
||||||
|
return s.addr
|
||||||
|
}
|
||||||
|
@ -2,12 +2,12 @@ package testsupport
|
|||||||
|
|
||||||
import "git.wamblee.org/converge/pkg/support/iowrappers"
|
import "git.wamblee.org/converge/pkg/support/iowrappers"
|
||||||
|
|
||||||
type dummyRemoteAddr string
|
type DummyRemoteAddr string
|
||||||
|
|
||||||
func (r dummyRemoteAddr) Network() string {
|
func (r DummyRemoteAddr) Network() string {
|
||||||
return string(r)
|
return string(r)
|
||||||
}
|
}
|
||||||
func (r dummyRemoteAddr) String() string {
|
func (r DummyRemoteAddr) String() string {
|
||||||
return string(r)
|
return string(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"git.wamblee.org/converge/pkg/support/pprof"
|
"git.wamblee.org/converge/pkg/support/pprof"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@ -73,3 +74,17 @@ func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Cont
|
|||||||
}
|
}
|
||||||
return ctx, compositeCancelFunc
|
return ctx, compositeCancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
|
||||||
|
n, err := writer.Write([]byte(data))
|
||||||
|
s.Nil(err)
|
||||||
|
s.Equal(len(data), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
|
||||||
|
buf := make([]byte, len(data)*2)
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
s.Nil(err)
|
||||||
|
s.Equal(len(data), n)
|
||||||
|
s.Equal(data, string(buf[:n]))
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user