280 lines
7.3 KiB
Go
280 lines
7.3 KiB
Go
package comms
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"git.wamblee.org/converge/pkg/testsupport"
|
|
"github.com/stretchr/testify/suite"
|
|
"go.uber.org/goleak"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type AgentServerTestSuite struct {
|
|
suite.Suite
|
|
|
|
ctx context.Context
|
|
cancelFunc context.CancelFunc
|
|
pprofServer *http.Server
|
|
|
|
agentReadWriter io.ReadWriteCloser
|
|
serverReadWriter io.ReadWriteCloser
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) SetupSuite() {
|
|
s.pprofServer = testsupport.StartPprof("")
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) TearDownSuite() {
|
|
testsupport.StopPprof(s.ctx, s.pprofServer)
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) SetupTest() {
|
|
ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second)
|
|
s.ctx = ctx
|
|
s.cancelFunc = cancelFunc
|
|
|
|
// Could have also used net.Pipe but net.Pipe uses synchronous communication
|
|
// by default and the bitpipe implementation can become asynchronous when
|
|
// a channels ize > 0 is passed in. Also the test utility respects the context
|
|
// so also deals with cancellation much better than net.Pipe.
|
|
bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10)
|
|
agentReadWriter := bitpipe.Front()
|
|
serverReadWriter := bitpipe.Back()
|
|
s.agentReadWriter = agentReadWriter
|
|
s.serverReadWriter = serverReadWriter
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) TearDownTest() {
|
|
agentProtocolVersion = PROTOCOL_VERSION
|
|
serverProtocolVersion = PROTOCOL_VERSION
|
|
s.cancelFunc()
|
|
goleak.VerifyNone(s.T())
|
|
}
|
|
|
|
func TestAgentServerTestSuite(t *testing.T) {
|
|
suite.Run(t, &AgentServerTestSuite{})
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() {
|
|
for _, clientId := range []string{"abc", "djkdfadfha"} {
|
|
buf := &bytes.Buffer{}
|
|
err := SendClientInfo(buf, clientId)
|
|
s.Nil(err)
|
|
clientIdReceived, err := ReceiveClientInfo(buf)
|
|
s.Nil(err)
|
|
s.Equal(clientId, clientIdReceived)
|
|
}
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
|
|
commChannels := testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
log.Println("Agent initializing")
|
|
commChannel, err := NewCommChannel(Agent, s.agentReadWriter)
|
|
s.Nil(err)
|
|
return commChannel
|
|
},
|
|
func() any {
|
|
log.Println("Server initializing")
|
|
commChannel, err := NewCommChannel(ConvergeServer, s.serverReadWriter)
|
|
s.Nil(err)
|
|
return commChannel
|
|
},
|
|
)
|
|
|
|
s.Equal(2, len(commChannels))
|
|
agentCommChannel := commChannels[0].(CommChannel)
|
|
serverCommChannel := commChannels[1].(CommChannel)
|
|
return agentCommChannel, serverCommChannel
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) TestNewCommChannel() {
|
|
// Setup Comm channel
|
|
agentCommChannel, serverCommChannel := s.createCommChannel()
|
|
|
|
// verify the side channel is working by sending an object
|
|
testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
protocolVersion := ProtocolVersion{Version: 10}
|
|
err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, protocolVersion)
|
|
s.Nil(err)
|
|
log.Printf("Sent one message %v", protocolVersion)
|
|
return nil
|
|
},
|
|
func() any {
|
|
protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel)
|
|
s.Nil(err)
|
|
log.Printf("Received one message %v", protocolVersion)
|
|
return nil
|
|
},
|
|
)
|
|
|
|
log.Printf("%v %v", agentCommChannel, serverCommChannel)
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) Test_ConnectThroughYamux() {
|
|
agentCommChannel, serverCommChannel := s.createCommChannel()
|
|
|
|
dataAgentToServer := "hello"
|
|
dataServerToAgent := "bye"
|
|
testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
conn, err := agentCommChannel.Session.OpenStream()
|
|
s.Nil(err)
|
|
testsupport.AssertWriteData(&s.Suite, dataAgentToServer, conn)
|
|
testsupport.AssertReadData(&s.Suite, dataServerToAgent, conn)
|
|
return nil
|
|
},
|
|
func() any {
|
|
conn, err := serverCommChannel.Session.Accept()
|
|
|
|
s.Nil(err)
|
|
testsupport.AssertReadData(&s.Suite, dataAgentToServer, conn)
|
|
testsupport.AssertWriteData(&s.Suite, dataServerToAgent, conn)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) Test_Initialization() {
|
|
serverTestId := rand.Int()
|
|
agentShell := "agentshell"
|
|
testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
serverInfo, err := AgentInitialization(s.agentReadWriter,
|
|
NewEnvironmentInfo(agentShell))
|
|
s.Nil(err)
|
|
s.Equal(serverTestId, serverInfo.TestId)
|
|
return nil
|
|
},
|
|
func() any {
|
|
serverInfo := ServerInfo{TestId: serverTestId}
|
|
environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
|
|
s.Nil(err)
|
|
s.Equal(agentShell, environmentInfo.Shell)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
|
|
serverProtocolVersion++
|
|
testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
serverInfo, err := AgentInitialization(s.agentReadWriter,
|
|
NewEnvironmentInfo("myshell"))
|
|
s.NotNil(err)
|
|
s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
|
|
s.Equal(ServerInfo{}, serverInfo)
|
|
return nil
|
|
},
|
|
func() any {
|
|
serverInfo := ServerInfo{TestId: 1000}
|
|
environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
|
|
s.NotNil(err)
|
|
s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
|
|
s.Equal(EnvironmentInfo{}, environmentInfo)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() {
|
|
s.agentReadWriter.Close()
|
|
s.checkInitializationFailure()
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() {
|
|
s.serverReadWriter.Close()
|
|
s.checkInitializationFailure()
|
|
}
|
|
|
|
func (s *AgentServerTestSuite) checkInitializationFailure() []any {
|
|
return testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
serverInfo, err := AgentInitialization(s.agentReadWriter,
|
|
NewEnvironmentInfo("myshell"))
|
|
s.NotNil(err)
|
|
s.Equal(ServerInfo{}, serverInfo)
|
|
return nil
|
|
},
|
|
func() any {
|
|
serverInfo := ServerInfo{TestId: 1000}
|
|
environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
|
|
s.NotNil(err)
|
|
s.Equal(EnvironmentInfo{}, environmentInfo)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// TODO:
|
|
// Tests for when the connection is closed from the agent and from the server: verify that an error is returned
|
|
|
|
func (s *AgentServerTestSuite) Test_ListenForAgentEvents() {
|
|
|
|
agentEvents := []any{
|
|
NewEnvironmentInfo("myshell"),
|
|
NewSessionInfo("1", "sftp"),
|
|
NewExpiryTimeUpdate(time.Now().Add(1 * time.Minute)),
|
|
HeartBeat{},
|
|
}
|
|
const nevents = 100
|
|
eventTypesSent := make([]int, nevents, nevents)
|
|
testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
channel := NewGOBChannel(s.agentReadWriter)
|
|
for i := range nevents {
|
|
ievent := rand.Int() % len(agentEvents)
|
|
eventTypesSent[i] = ievent
|
|
event := ConvergeMessage{
|
|
Value: agentEvents[ievent],
|
|
}
|
|
err := SendWithTimeout[ConvergeMessage](channel, event)
|
|
s.Nil(err)
|
|
}
|
|
// pending events will still be sent.
|
|
s.agentReadWriter.Close()
|
|
return nil
|
|
},
|
|
func() any {
|
|
eventTypesReceived := make([]int, nevents, nevents)
|
|
channel := NewGOBChannel(s.serverReadWriter)
|
|
i := 0
|
|
ListenForAgentEvents(channel,
|
|
func(agent EnvironmentInfo) {
|
|
eventTypesReceived[i] = 0
|
|
i++
|
|
},
|
|
func(session SessionInfo) {
|
|
eventTypesReceived[i] = 1
|
|
i++
|
|
},
|
|
func(expiryTimeUpdate ExpiryTimeUpdate) {
|
|
eventTypesReceived[i] = 2
|
|
i++
|
|
},
|
|
func(hearbeat HeartBeat) {
|
|
eventTypesReceived[i] = 3
|
|
i++
|
|
},
|
|
)
|
|
s.Equal(eventTypesSent, eventTypesReceived)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// This is currently a Noop. No need to test it.
|
|
//func (s *AgentServerTestSuite) Test_ListenForServerEvents() {
|
|
//
|
|
//}
|