converge/pkg/comms/agentserver_test.go
Erik Brakkee 7d25f39f5b test for connecting clients and bidirectional communication to agent.
Required lots of rework since the GOBChannel appeared to be reading
ahead of the data it actually needed. Now using more low-level IO
to send the clientId over to the agent instead.
2024-08-22 16:16:02 +02:00

293 lines
7.6 KiB
Go

package comms
import (
"bytes"
"context"
"git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite"
"io"
"log"
"math/rand"
"net/http"
"strings"
"testing"
"time"
)
// AgentServerTestSuite is the testify suite that exercises the agent/server
// communication code of this package over an in-memory connection.
type AgentServerTestSuite struct {
	suite.Suite

	// ctx and cancelFunc are replaced by SetupTest before every test.
	ctx        context.Context
	cancelFunc context.CancelFunc

	// pprofServer is started in SetupSuite and stopped in TearDownSuite.
	pprofServer *http.Server

	// agentReadWriter and serverReadWriter are the Front and Back ends of the
	// in-memory connection that SetupTest creates for every test.
	agentReadWriter, serverReadWriter io.ReadWriteCloser
}
// SetupSuite starts the pprof server via testsupport.StartPprof and keeps the
// returned handle on the suite so that it can be stopped again later.
func (s *AgentServerTestSuite) SetupSuite() {
	server := testsupport.StartPprof("")
	s.pprofServer = server
}
// TearDownSuite stops the pprof server that SetupSuite started.
//
// The shutdown gets its own bounded context instead of reusing s.ctx: s.ctx is
// only assigned in SetupTest, so it is nil when no test ran and otherwise
// belongs to whichever test happened to run last. Suite-level cleanup should
// not depend on per-test state.
func (s *AgentServerTestSuite) TearDownSuite() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	testsupport.StopPprof(ctx, s.pprofServer)
}
// SetupTest gives every test a fresh test context and a fresh in-memory
// connection whose two ends act as the agent side and the server side.
func (s *AgentServerTestSuite) SetupTest() {
	s.ctx, s.cancelFunc = testsupport.CreateTestContext(context.Background(), 10*time.Second)

	// net.Pipe would have been an alternative, but net.Pipe communicates
	// synchronously by default, whereas the bitpipe implementation can become
	// asynchronous when a channel size > 0 is passed in. The test utility also
	// respects the context, so it deals with cancellation much better than
	// net.Pipe.
	pipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10)
	s.agentReadWriter = pipe.Front()
	s.serverReadWriter = pipe.Back()
}
// TearDownTest resets the package-level agentProtocolVersion and
// serverProtocolVersion to PROTOCOL_VERSION, so that a test which modified
// them cannot influence the tests that run afterwards.
//
// The receiver is named s, like on every other method of this type; the
// previous name "suite" shadowed the imported suite package.
//
// NOTE(review): the cancelFunc stored by SetupTest is never invoked in this
// file. Consider calling it here after confirming that no later code still
// needs s.ctx to be live.
func (s *AgentServerTestSuite) TearDownTest() {
	agentProtocolVersion = PROTOCOL_VERSION
	serverProtocolVersion = PROTOCOL_VERSION
}
func TestAgentServerTestSuite(t *testing.T) {
suite.Run(t, &AgentServerTestSuite{})
}
// TestClientInfoEncodeDecode checks that a client id written with
// SendClientInfo is read back unchanged by ReceiveClientInfo, using a plain
// in-memory buffer as the transport.
func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() {
	ids := []string{"abc", "djkdfadfha"}
	for _, want := range ids {
		var wire bytes.Buffer

		sendErr := SendClientInfo(&wire, want)
		s.Nil(sendErr)

		got, recvErr := ReceiveClientInfo(&wire)
		s.Nil(recvErr)
		s.Equal(want, got)
	}
}
// createCommChannel sets up a CommChannel on both ends of the in-memory
// connection. The agent side and the server side are initialized concurrently
// through testsupport.RunAndWait.
//
// It returns the agent CommChannel first and the server CommChannel second.
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
	initAgent := func() any {
		log.Println("Agent initializing")
		channel, err := NewCommChannel(Agent, s.agentReadWriter)
		s.Nil(err)
		return channel
	}
	initServer := func() any {
		log.Println("Server initializing")
		channel, err := NewCommChannel(ConvergeServer, s.serverReadWriter)
		s.Nil(err)
		return channel
	}

	results := testsupport.RunAndWait(&s.Suite, initAgent, initServer)
	s.Equal(2, len(results))

	// Results are indexed in the order the functions were passed in.
	return results[0].(CommChannel), results[1].(CommChannel)
}
// TestNewCommChannel sets up a CommChannel on both sides and verifies that the
// side channel works by sending one ProtocolVersion message from the agent to
// the server and checking that the received version equals the one sent.
func (s *AgentServerTestSuite) TestNewCommChannel() {
	// Setup Comm channel
	agentCommChannel, serverCommChannel := s.createCommChannel()

	// Shared read-only by both goroutines below.
	sent := ProtocolVersion{Version: 10}

	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, sent)
			s.Nil(err)
			log.Printf("Sent one message %v", sent)
			return nil
		},
		func() any {
			received, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel)
			// Only inspect the payload when the receive succeeded; previously
			// the content of the received message was never checked at all.
			if s.Nil(err) {
				s.Equal(sent.Version, received.Version)
			}
			log.Printf("Received one message %v", received)
			return nil
		},
	)
	log.Printf("%v %v", agentCommChannel, serverCommChannel)
}
// Test_ConnectThroughYamux opens a stream from the agent session, accepts it
// on the server session, and verifies a round trip: the agent writes "hello",
// the server reads it and answers "bye", which the agent reads back.
//
// Reads use io.ReadFull because a single Read call is allowed to return fewer
// bytes than requested. A failure to open or accept the stream ends that side
// early so that a nil stream is never used from within the goroutine.
func (s *AgentServerTestSuite) Test_ConnectThroughYamux() {
	agentCommChannel, serverCommChannel := s.createCommChannel()
	dataAgentToServer := "hello"
	dataServerToAgent := "bye"
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			conn, err := agentCommChannel.Session.OpenStream()
			if !s.Nil(err) {
				return nil
			}
			n, err := conn.Write([]byte(dataAgentToServer))
			s.Nil(err)
			s.Equal(len(dataAgentToServer), n)

			buf := make([]byte, len(dataServerToAgent))
			n, err = io.ReadFull(conn, buf)
			s.Nil(err)
			s.Equal(len(dataServerToAgent), n)
			s.Equal([]byte(dataServerToAgent), buf[:n])
			return nil
		},
		func() any {
			conn, err := serverCommChannel.Session.Accept()
			if !s.Nil(err) {
				return nil
			}
			buf := make([]byte, len(dataAgentToServer))
			n, err := io.ReadFull(conn, buf)
			s.Nil(err)
			s.Equal(len(dataAgentToServer), n)
			s.Equal([]byte(dataAgentToServer), buf[:n])

			n, err = conn.Write([]byte(dataServerToAgent))
			s.Nil(err)
			s.Equal(len(dataServerToAgent), n)
			return nil
		})
}
// Test_Initialization runs AgentInitialization and ServerInitialization
// against each other and checks that each side receives what the other side
// sent: the agent must see the server's TestId and the server must see the
// agent's shell.
func (s *AgentServerTestSuite) Test_Initialization() {
	agentShell := "agentshell"
	serverTestId := rand.Int()

	agentSide := func() any {
		info, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo(agentShell))
		s.Nil(err)
		s.Equal(serverTestId, info.TestId)
		return nil
	}
	serverSide := func() any {
		env, err := ServerInitialization(s.serverReadWriter, ServerInfo{TestId: serverTestId})
		s.Nil(err)
		s.Equal(agentShell, env.Shell)
		return nil
	}

	testsupport.RunAndWait(&s.Suite, agentSide, serverSide)
}
// Test_InitializationProtocolVersionMismatch increments the package-level
// serverProtocolVersion and then expects both AgentInitialization and
// ServerInitialization to fail with an error that mentions "protocol"
// (case-insensitive) and to return zero-valued results.
//
// The error text is only inspected when an error was actually returned, so a
// missing error shows up as a failed assertion instead of a nil-pointer panic
// inside the goroutine.
func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
	serverProtocolVersion++
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			serverInfo, err := AgentInitialization(s.agentReadWriter,
				NewEnvironmentInfo("myshell"))
			if s.NotNil(err) {
				s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
			}
			s.Equal(ServerInfo{}, serverInfo)
			return nil
		},
		func() any {
			serverInfo := ServerInfo{TestId: 1000}
			environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
			if s.NotNil(err) {
				s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
			}
			s.Equal(EnvironmentInfo{}, environmentInfo)
			return nil
		})
}
// Test_InitializationAgentConnectionClosed closes the agent end of the
// connection up front and then expects initialization to fail on both sides.
func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() {
	// The result of Close is deliberately not checked.
	_ = s.agentReadWriter.Close()
	_ = s.checkInitializationFailure()
}
// Test_InitializationServerConnectionClosed closes the server end of the
// connection up front and then expects initialization to fail on both sides.
func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() {
	// The result of Close is deliberately not checked.
	_ = s.serverReadWriter.Close()
	_ = s.checkInitializationFailure()
}
// checkInitializationFailure runs the agent and server initialization
// concurrently and asserts that both return an error together with a
// zero-valued result. It returns whatever testsupport.RunAndWait returns.
func (s *AgentServerTestSuite) checkInitializationFailure() []any {
	agentSide := func() any {
		info, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo("myshell"))
		s.NotNil(err)
		s.Equal(ServerInfo{}, info)
		return nil
	}
	serverSide := func() any {
		env, err := ServerInitialization(s.serverReadWriter, ServerInfo{TestId: 1000})
		s.NotNil(err)
		s.Equal(EnvironmentInfo{}, env)
		return nil
	}

	results := testsupport.RunAndWait(&s.Suite, agentSide, serverSide)
	return results
}
// TODO:
// Tests for when the connection is closed from the agent and from the server: verify that an error is returned
// Test_ListenForAgentEvents sends nevents randomly chosen agent events over a
// GOB channel and verifies that ListenForAgentEvents dispatches them to the
// matching callbacks in the same order.
//
// Each event is identified by its index in agentEvents. Received event types
// are appended as they arrive rather than written into a pre-sized slice, so
// that missing events show up as a length mismatch (a zero-filled slice would
// hide dropped events of type 0) and surplus events cannot index out of range.
func (s *AgentServerTestSuite) Test_ListenForAgentEvents() {
	agentEvents := []any{
		NewEnvironmentInfo("myshell"),
		NewSessionInfo("1", "sftp"),
		NewExpiryTimeUpdate(time.Now().Add(1 * time.Minute)),
		HeartBeat{},
	}
	const nevents = 100
	eventTypesSent := make([]int, nevents)
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			channel := NewGOBChannel(s.agentReadWriter)
			for i := range nevents {
				ievent := rand.Intn(len(agentEvents))
				eventTypesSent[i] = ievent
				event := ConvergeMessage{
					Value: agentEvents[ievent],
				}
				err := SendWithTimeout[ConvergeMessage](channel, event)
				s.Nil(err)
			}
			// pending events will still be sent.
			s.agentReadWriter.Close()
			return nil
		},
		func() any {
			eventTypesReceived := make([]int, 0, nevents)
			record := func(eventType int) {
				eventTypesReceived = append(eventTypesReceived, eventType)
			}
			channel := NewGOBChannel(s.serverReadWriter)
			// The recorded numbers are the indices of the corresponding
			// entries in agentEvents.
			ListenForAgentEvents(channel,
				func(EnvironmentInfo) { record(0) },
				func(SessionInfo) { record(1) },
				func(ExpiryTimeUpdate) { record(2) },
				func(HeartBeat) { record(3) },
			)
			s.Equal(eventTypesSent, eventTypesReceived)
			return nil
		})
}
// This is currently a no-op. No need to test it.
//func (s *AgentServerTestSuite) Test_LIstenForServerEvents() {
//
//}