test case for single agent registration.

This commit is contained in:
Erik Brakkee 2024-08-22 10:28:54 +02:00
parent b37fb4bc67
commit 25ba2cf464
3 changed files with 131 additions and 20 deletions

View File

@ -20,8 +20,8 @@ type AgentServerTestSuite struct {
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
pprofServer *http.Server pprofServer *http.Server
agentConnection io.ReadWriteCloser agentReadWriter io.ReadWriteCloser
serverConnection io.ReadWriteCloser serverReadWriter io.ReadWriteCloser
} }
func (s *AgentServerTestSuite) SetupSuite() { func (s *AgentServerTestSuite) SetupSuite() {
@ -42,10 +42,10 @@ func (s *AgentServerTestSuite) SetupTest() {
// a channels ize > 0 is passed in. Also the test utility respects the context // a channels ize > 0 is passed in. Also the test utility respects the context
// so also deals with cancellation much better than net.Pipe. // so also deals with cancellation much better than net.Pipe.
bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10) bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10)
agentConnection := bitpipe.Front() agentReadWriter := bitpipe.Front()
serverConnection := bitpipe.Back() serverReadWriter := bitpipe.Back()
s.agentConnection = agentConnection s.agentReadWriter = agentReadWriter
s.serverConnection = serverConnection s.serverReadWriter = serverReadWriter
} }
func (suite *AgentServerTestSuite) TearDownTest() { func (suite *AgentServerTestSuite) TearDownTest() {
@ -62,13 +62,13 @@ func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
&s.Suite, &s.Suite,
func() any { func() any {
log.Println("Agent initializing") log.Println("Agent initializing")
commChannel, err := NewCommChannel(Agent, s.agentConnection) commChannel, err := NewCommChannel(Agent, s.agentReadWriter)
s.Nil(err) s.Nil(err)
return commChannel return commChannel
}, },
func() any { func() any {
log.Println("Server initializing") log.Println("Server initializing")
commChannel, err := NewCommChannel(ConvergeServer, s.serverConnection) commChannel, err := NewCommChannel(ConvergeServer, s.serverReadWriter)
s.Nil(err) s.Nil(err)
return commChannel return commChannel
}, },
@ -151,7 +151,7 @@ func (s *AgentServerTestSuite) Test_Initialization() {
testsupport.RunAndWait( testsupport.RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
serverInfo, err := AgentInitialization(s.agentConnection, serverInfo, err := AgentInitialization(s.agentReadWriter,
NewEnvironmentInfo(agentShell)) NewEnvironmentInfo(agentShell))
s.Nil(err) s.Nil(err)
s.Equal(serverTestId, serverInfo.TestId) s.Equal(serverTestId, serverInfo.TestId)
@ -159,7 +159,7 @@ func (s *AgentServerTestSuite) Test_Initialization() {
}, },
func() any { func() any {
serverInfo := ServerInfo{TestId: serverTestId} serverInfo := ServerInfo{TestId: serverTestId}
environmentInfo, err := ServerInitialization(s.serverConnection, serverInfo) environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
s.Nil(err) s.Nil(err)
s.Equal(agentShell, environmentInfo.Shell) s.Equal(agentShell, environmentInfo.Shell)
return nil return nil
@ -171,7 +171,7 @@ func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
testsupport.RunAndWait( testsupport.RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
serverInfo, err := AgentInitialization(s.agentConnection, serverInfo, err := AgentInitialization(s.agentReadWriter,
NewEnvironmentInfo("myshell")) NewEnvironmentInfo("myshell"))
s.NotNil(err) s.NotNil(err)
s.True(strings.Contains(strings.ToLower(err.Error()), "protocol")) s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
@ -180,7 +180,7 @@ func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
}, },
func() any { func() any {
serverInfo := ServerInfo{TestId: 1000} serverInfo := ServerInfo{TestId: 1000}
environmentInfo, err := ServerInitialization(s.serverConnection, serverInfo) environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
s.NotNil(err) s.NotNil(err)
s.True(strings.Contains(strings.ToLower(err.Error()), "protocol")) s.True(strings.Contains(strings.ToLower(err.Error()), "protocol"))
s.Equal(EnvironmentInfo{}, environmentInfo) s.Equal(EnvironmentInfo{}, environmentInfo)
@ -189,12 +189,12 @@ func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
} }
func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() { func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() {
s.agentConnection.Close() s.agentReadWriter.Close()
s.checkInitializationFailure() s.checkInitializationFailure()
} }
func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() { func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() {
s.serverConnection.Close() s.serverReadWriter.Close()
s.checkInitializationFailure() s.checkInitializationFailure()
} }
@ -202,7 +202,7 @@ func (s *AgentServerTestSuite) checkInitializationFailure() []any {
return testsupport.RunAndWait( return testsupport.RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
serverInfo, err := AgentInitialization(s.agentConnection, serverInfo, err := AgentInitialization(s.agentReadWriter,
NewEnvironmentInfo("myshell")) NewEnvironmentInfo("myshell"))
s.NotNil(err) s.NotNil(err)
s.Equal(ServerInfo{}, serverInfo) s.Equal(ServerInfo{}, serverInfo)
@ -210,7 +210,7 @@ func (s *AgentServerTestSuite) checkInitializationFailure() []any {
}, },
func() any { func() any {
serverInfo := ServerInfo{TestId: 1000} serverInfo := ServerInfo{TestId: 1000}
environmentInfo, err := ServerInitialization(s.serverConnection, serverInfo) environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
s.NotNil(err) s.NotNil(err)
s.Equal(EnvironmentInfo{}, environmentInfo) s.Equal(EnvironmentInfo{}, environmentInfo)
return nil return nil
@ -233,7 +233,7 @@ func (s *AgentServerTestSuite) Test_ListenForAgentEvents() {
testsupport.RunAndWait( testsupport.RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
channel := NewGOBChannel(s.agentConnection) channel := NewGOBChannel(s.agentReadWriter)
for i := range nevents { for i := range nevents {
ievent := rand.Int() % len(agentEvents) ievent := rand.Int() % len(agentEvents)
eventTypesSent[i] = ievent eventTypesSent[i] = ievent
@ -244,12 +244,12 @@ func (s *AgentServerTestSuite) Test_ListenForAgentEvents() {
s.Nil(err) s.Nil(err)
} }
// pending events will still be sent. // pending events will still be sent.
s.agentConnection.Close() s.agentReadWriter.Close()
return nil return nil
}, },
func() any { func() any {
eventTypesReceived := make([]int, nevents, nevents) eventTypesReceived := make([]int, nevents, nevents)
channel := NewGOBChannel(s.serverConnection) channel := NewGOBChannel(s.serverReadWriter)
i := 0 i := 0
ListenForAgentEvents(channel, ListenForAgentEvents(channel,
func(agent EnvironmentInfo) { func(agent EnvironmentInfo) {

View File

@ -116,12 +116,15 @@ func (admin *Admin) AddAgent(hostKey []byte, publicId models.RendezVousId, agent
message = "The server allocated a new id." message = "The server allocated a new id."
} }
publicId = newPublicId publicId = newPublicId
comms.SendRegistrationMessage(conn, comms.AgentRegistration{ err := comms.SendRegistrationMessage(conn, comms.AgentRegistration{
Ok: true, Ok: true,
Message: message, Message: message,
Id: string(publicId), Id: string(publicId),
HostPrivateKey: hostKey, HostPrivateKey: hostKey,
}) })
if err != nil {
return nil, err
}
} else { } else {
comms.SendRegistrationMessage(conn, comms.AgentRegistration{ comms.SendRegistrationMessage(conn, comms.AgentRegistration{
Ok: false, Ok: false,

View File

@ -0,0 +1,108 @@
package admin
import (
"context"
"crypto/rand"
"git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/models"
"git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite"
"io"
"net/http"
"testing"
"time"
)
// test cases
//
// Agent only: verify state, verify agentregistration message
// - Connect single agent
// - Connect agent, connect second agent with duplicate id
// -> new id taken out
// - Connect more than 100 agents with the same id
// -> 101th agents gets error
//
// Client connected to agent: Verify clientConnection and agentCOnnection, verify state, verify clientInfo message.
// - Connect agent + connect client with mmtching id
// - Connect agent + connect client with wrong id
//
// Overall:
// - Connect agent, connect 2 clients
// - Connect multiple agents and clients
type AdminTestSuite struct {
suite.Suite
ctx context.Context
cancelFunc context.CancelFunc
pprofServer *http.Server
agentReadWriter io.ReadWriteCloser
serverReadWriter io.ReadWriteCloser
admin *Admin
hostKey []byte
}
func (s *AdminTestSuite) SetupSuite() {
s.pprofServer = testsupport.StartPprof("")
}
func (s *AdminTestSuite) TearDownSuite() {
testsupport.StopPprof(s.ctx, s.pprofServer)
}
func (s *AdminTestSuite) SetupTest() {
ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second)
s.ctx = ctx
s.cancelFunc = cancelFunc
// Could have also used net.Pipe but net.Pipe uses synchronous communication
// by default and the bitpipe implementation can become asynchronous when
// a channels ize > 0 is passed in. Also the test utility respects the context
// so also deals with cancellation much better than net.Pipe.
bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10)
agentReadWriter := bitpipe.Front()
serverReadWriter := bitpipe.Back()
s.agentReadWriter = agentReadWriter
s.serverReadWriter = serverReadWriter
s.admin = NewAdmin()
s.hostKey = make([]byte, 100)
rand.Read(s.hostKey)
}
func (suite *AdminTestSuite) TearDownTest() {
}
func TestAdminTestSuite(t *testing.T) {
suite.Run(t, &AdminTestSuite{})
}
func (s *AdminTestSuite) Test_AgentRegisters() {
publicId := "abc"
testsupport.RunAndWait(
&s.Suite,
func() any {
agentConn, err := s.admin.AddAgent(s.hostKey, models.RendezVousId(publicId), comms.EnvironmentInfo{}, s.serverReadWriter)
s.Nil(err)
s.Equal(publicId, string(agentConn.Info.PublicId))
state := s.admin.CreateNotifification()
s.Equal(1, len(state.Agents))
s.Equal(0, len(state.Clients))
s.Equal(agentConn.Info, state.Agents[agentConn.Info.Guid])
return nil
},
func() any {
// verify registration message received
agentRegistration, err := comms.ReceiveRegistrationMessage(s.agentReadWriter)
s.Nil(err)
s.True(agentRegistration.Ok)
s.Equal(s.hostKey, agentRegistration.HostPrivateKey)
commChannel, err := comms.NewCommChannel(comms.Agent, s.agentReadWriter)
s.Nil(err)
s.NotNil(commChannel)
return nil
})
}