package comms import ( "context" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" "io" "log" "math/rand" "net/http" "strings" "testing" "time" ) type AgentServerTestSuite struct { suite.Suite ctx context.Context cancelFunc context.CancelFunc pprofServer *http.Server agentReadWriter io.ReadWriteCloser serverReadWriter io.ReadWriteCloser } func (s *AgentServerTestSuite) SetupSuite() { s.pprofServer = testsupport.StartPprof("") } func (s *AgentServerTestSuite) TearDownSuite() { testsupport.StopPprof(s.ctx, s.pprofServer) } func (s *AgentServerTestSuite) SetupTest() { ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second) s.ctx = ctx s.cancelFunc = cancelFunc // Could have also used net.Pipe but net.Pipe uses synchronous communication // by default and the bitpipe implementation can become asynchronous when // a channels ize > 0 is passed in. Also the test utility respects the context // so also deals with cancellation much better than net.Pipe. bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10) agentReadWriter := bitpipe.Front() serverReadWriter := bitpipe.Back() s.agentReadWriter = agentReadWriter s.serverReadWriter = serverReadWriter } func (suite *AgentServerTestSuite) TearDownTest() { agentProtocolVersion = PROTOCOL_VERSION serverProtocolVersion = PROTOCOL_VERSION } func TestAgentServerTestSuite(t *testing.T) { suite.Run(t, &AgentServerTestSuite{}) } func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) { commChannels := testsupport.RunAndWait( &s.Suite, func() any { log.Println("Agent initializing") commChannel, err := NewCommChannel(Agent, s.agentReadWriter) s.Nil(err) return commChannel }, func() any { log.Println("Server initializing") commChannel, err := NewCommChannel(ConvergeServer, s.serverReadWriter) s.Nil(err) return commChannel }, ) s.Equal(2, len(commChannels)) agentCommChannel := commChannels[0].(CommChannel) serverCommChannel := commChannels[1].(CommChannel) return agentCommChannel, serverCommChannel } func (s *AgentServerTestSuite) TestNewCommChannel() { // Setup Comm channel agentCommChannel, serverCommChannel := s.createCommChannel() // verify the side channel is working by sending an object testsupport.RunAndWait( &s.Suite, func() any { protocolVersion := ProtocolVersion{Version: 10} err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, protocolVersion) s.Nil(err) log.Printf("Sent one message %v", protocolVersion) return nil }, func() any { protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel) s.Nil(err) log.Printf("Received one message %v", protocolVersion) return nil }, ) log.Printf("%v %v", agentCommChannel, serverCommChannel) } func (s *AgentServerTestSuite) Test_ConnectThroughYamux() { agentCommChannel, serverCommChannel := s.createCommChannel() dataAgentToServer := "hello" dataServerToAgent := "bye" testsupport.RunAndWait( &s.Suite, func() any { conn, err := agentCommChannel.Session.OpenStream() s.Nil(err) n, err := conn.Write([]byte(dataAgentToServer)) s.Nil(err) s.Equal(len(dataAgentToServer), n) buf := make([]byte, len(dataServerToAgent)) n, err = conn.Read(buf) s.Nil(err) s.Equal(len(dataServerToAgent), n) s.Equal([]byte(dataServerToAgent), buf[:n]) return nil }, func() any { conn, err := serverCommChannel.Session.Accept() s.Nil(err) buf := make([]byte, len(dataAgentToServer)) n, err := conn.Read(buf) s.Nil(err) s.Equal(len(dataAgentToServer), n) s.Equal([]byte(dataAgentToServer), buf[:n]) n, err = conn.Write([]byte(dataServerToAgent)) s.Nil(err) s.Equal(len(dataServerToAgent), n) return nil }) } func (s *AgentServerTestSuite) Test_Initialization() { serverTestId := rand.Int() agentShell := "agentshell" testsupport.RunAndWait( &s.Suite, func() any { serverInfo, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo(agentShell)) s.Nil(err) s.Equal(serverTestId, serverInfo.TestId) return nil }, func() any { serverInfo := ServerInfo{TestId: serverTestId} environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo) s.Nil(err) s.Equal(agentShell, environmentInfo.Shell) return nil }) } func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() { serverProtocolVersion++ testsupport.RunAndWait( &s.Suite, func() any { serverInfo, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo("myshell")) s.NotNil(err) s.True(strings.Contains(strings.ToLower(err.Error()), "protocol")) s.Equal(ServerInfo{}, serverInfo) return nil }, func() any { serverInfo := ServerInfo{TestId: 1000} environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo) s.NotNil(err) s.True(strings.Contains(strings.ToLower(err.Error()), "protocol")) s.Equal(EnvironmentInfo{}, environmentInfo) return nil }) } func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() { s.agentReadWriter.Close() s.checkInitializationFailure() } func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() { s.serverReadWriter.Close() s.checkInitializationFailure() } func (s *AgentServerTestSuite) checkInitializationFailure() []any { return testsupport.RunAndWait( &s.Suite, func() any { serverInfo, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo("myshell")) s.NotNil(err) s.Equal(ServerInfo{}, serverInfo) return nil }, func() any { serverInfo := ServerInfo{TestId: 1000} environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo) s.NotNil(err) s.Equal(EnvironmentInfo{}, environmentInfo) return nil }) } // TODO: // Tests when connection is close from agent and from server: verify error is returned func (s *AgentServerTestSuite) Test_ListenForAgentEvents() { agentEvents := []any{ NewEnvironmentInfo("myshell"), NewSessionInfo("1", "sftp"), NewExpiryTimeUpdate(time.Now().Add(1 * time.Minute)), HeartBeat{}, } const nevents = 100 eventTypesSent := make([]int, nevents, nevents) testsupport.RunAndWait( &s.Suite, func() any { channel := NewGOBChannel(s.agentReadWriter) for i := range nevents { ievent := rand.Int() % len(agentEvents) eventTypesSent[i] = ievent event := ConvergeMessage{ Value: agentEvents[ievent], } err := SendWithTimeout[ConvergeMessage](channel, event) s.Nil(err) } // pending events will still be sent. s.agentReadWriter.Close() return nil }, func() any { eventTypesReceived := make([]int, nevents, nevents) channel := NewGOBChannel(s.serverReadWriter) i := 0 ListenForAgentEvents(channel, func(agent EnvironmentInfo) { eventTypesReceived[i] = 0 i++ }, func(session SessionInfo) { eventTypesReceived[i] = 1 i++ }, func(expiryTimeUpdate ExpiryTimeUpdate) { eventTypesReceived[i] = 2 i++ }, func(hearbeat HeartBeat) { eventTypesReceived[i] = 3 i++ }, ) s.Equal(eventTypesSent, eventTypesReceived) return nil }) } // This is currently a Noop. No need to test it. //func (s *AgentServerTestSuite) Test_LIstenForServerEvents() { // //}