diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go index b235e99..0457cfe 100644 --- a/pkg/server/admin/admin_test.go +++ b/pkg/server/admin/admin_test.go @@ -3,6 +3,7 @@ package admin import ( "context" "crypto/rand" + "fmt" "git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/models" "git.wamblee.org/converge/pkg/testsupport" @@ -71,33 +72,69 @@ func TestAdminTestSuite(t *testing.T) { suite.Run(t, &AdminTestSuite{}) } -func (s *AdminTestSuite) Test_AgentRegisters() { - publicId := "abc" +func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId string) *agentConnection { agentRW, serverRW := s.createPipe() - testsupport.RunAndWait( + res := testsupport.RunAndWait( &s.Suite, func() any { - agentConn, err := s.admin.AddAgent( - s.hostKey, models.RendezVousId(publicId), comms.EnvironmentInfo{}, - serverRW) - s.Nil(err) - s.Equal(publicId, string(agentConn.Info.PublicId)) - state := s.admin.CreateNotifification() - s.Equal(1, len(state.Agents)) - s.Equal(0, len(state.Clients)) - s.Equal(agentConn.Info, state.Agents[agentConn.Info.Guid]) - return nil + return s.addAgent(requestedPublicId, serverRW) }, func() any { - // verify registration message received - agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW) - s.Nil(err) - s.True(agentRegistration.Ok) - s.Equal(s.hostKey, agentRegistration.HostPrivateKey) - - commChannel, err := comms.NewCommChannel(comms.Agent, agentRW) - s.Nil(err) - s.NotNil(commChannel) - return nil + return s.agentRegistration(assignedPublicId, agentRW) }) + return res[0].(*agentConnection) +} + +func (s *AdminTestSuite) Test_AgentRegisters() { + agentConn := s.agentRegisters("abc", "abc") + state := s.admin.CreateNotifification() + s.Equal(1, len(state.Agents)) + s.Equal(0, len(state.Clients)) + s.Equal(agentConn.Info, state.Agents[agentConn.Info.Guid]) +} + +func (s *AdminTestSuite) Test_ManyAgentsRegister() { + N := 10 + agentRegistrations := make([]testsupport.TestFunction, N) + for i := range N { + publicId := fmt.Sprintf("abc%d", i) + agentRegistrations[i] = func() any { + return s.agentRegisters(publicId, publicId) + } + } + res := testsupport.RunAndWait( + &s.Suite, + agentRegistrations...) + state := s.admin.CreateNotifification() + s.Equal(len(res), len(state.Agents)) + s.Equal(0, len(state.Clients)) + for _, entry := range res { + agentConn := entry.(*agentConnection) + s.Equal(agentConn.Info, state.Agents[agentConn.Info.Guid]) + } + +} + +func (s *AdminTestSuite) agentRegistration(expectedPublicId string, agentRW io.ReadWriteCloser) any { + // verify registration message received + agentRegistration, err := comms.ReceiveRegistrationMessage(agentRW) + s.Nil(err) + s.True(agentRegistration.Ok) + s.Equal(expectedPublicId, agentRegistration.Id) + s.Equal(s.hostKey, agentRegistration.HostPrivateKey) + + commChannel, err := comms.NewCommChannel(comms.Agent, agentRW) + s.Nil(err) + s.NotNil(commChannel) + return nil +} + +func (s *AdminTestSuite) addAgent(publicId string, serverRW io.ReadWriteCloser) any { + agentConn, err := s.admin.AddAgent( + s.hostKey, models.RendezVousId(publicId), comms.EnvironmentInfo{}, + serverRW) + s.Nil(err) + s.Equal(publicId, string(agentConn.Info.PublicId)) + + return agentConn }