From ea0b4282ba809e41ae5be1cf6037c1b2e6927624 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Wed, 21 Aug 2024 17:48:47 +0200 Subject: [PATCH] test for ListenForAgentEvents implemented. --- pkg/comms/agentserver.go | 4 ++- pkg/comms/agentserver_test.go | 50 +++++++++++++++++++++++++++++ pkg/comms/events.go | 2 +- pkg/server/matchmaker/matchmaker.go | 3 ++ pkg/testsupport/utils.go | 7 ++++ 5 files changed, 64 insertions(+), 2 deletions(-) diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 6a9eeed..0437e82 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -91,7 +91,8 @@ func SetupHeartBeat(commChannel CommChannel) { func ListenForAgentEvents(channel GOBChannel, agentInfo func(agent EnvironmentInfo), sessionInfo func(session SessionInfo), - expiryTimeUpdate func(session ExpiryTimeUpdate)) { + expiryTimeUpdate func(session ExpiryTimeUpdate), + heartBeat func(heartbeat HeartBeat)) { for { var result ConvergeMessage err := channel.Decoder.Decode(&result) @@ -115,6 +116,7 @@ func ListenForAgentEvents(channel GOBChannel, // for not ignoring, can also implement behavior // when heartbeat not received but hearbeat is only // intended to keep the connection up + heartBeat(v) default: fmt.Printf(" Unknown type: %v %T\n", v, v) diff --git a/pkg/comms/agentserver_test.go b/pkg/comms/agentserver_test.go index 061b897..d37622c 100644 --- a/pkg/comms/agentserver_test.go +++ b/pkg/comms/agentserver_test.go @@ -165,6 +165,56 @@ func (s *AgentServerTestSuite) Test_Initialization() { func (s *AgentServerTestSuite) Test_ListenForAgentEvents() { + agentEvents := []any{ + NewEnvironmentInfo("myshell"), + NewSessionInfo("1", "sftp"), + NewExpiryTimeUpdate(time.Now().Add(1 * time.Minute)), + HeartBeat{}, + } + const nevents = 100 + eventTypesSent := make([]int, nevents, nevents) + testsupport.RunAndWait( + &s.Suite, + func() any { + channel := NewGOBChannel(s.agentConnection) + for i := range nevents { + ievent := rand.Int() % len(agentEvents) + eventTypesSent[i] = ievent + event := ConvergeMessage{ + Value: agentEvents[ievent], + } + err := SendWithTimeout[ConvergeMessage](channel, event) + s.Nil(err) + } + // pending events will still be sent. + s.agentConnection.Close() + return nil + }, + func() any { + eventTypesReceived := make([]int, nevents, nevents) + channel := NewGOBChannel(s.serverConnection) + i := 0 + ListenForAgentEvents(channel, + func(agent EnvironmentInfo) { + eventTypesReceived[i] = 0 + i++ + }, + func(session SessionInfo) { + eventTypesReceived[i] = 1 + i++ + }, + func(expiryTimeUpdate ExpiryTimeUpdate) { + eventTypesReceived[i] = 2 + i++ + }, + func(hearbeat HeartBeat) { + eventTypesReceived[i] = 3 + i++ + }, + ) + s.Equal(eventTypesSent, eventTypesReceived) + return nil + }) } func (s *AgentServerTestSuite) Test_LIstenForServerEvents() { diff --git a/pkg/comms/events.go b/pkg/comms/events.go index cf4d88f..41e5beb 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -67,7 +67,7 @@ type AgentRegistration struct { // Generic wrapper message required to send messages of arbitrary type type ConvergeMessage struct { - Value interface{} + Value any } func NewEnvironmentInfo(shell string) EnvironmentInfo { diff --git a/pkg/server/matchmaker/matchmaker.go b/pkg/server/matchmaker/matchmaker.go index 8542992..f1fe1ca 100644 --- a/pkg/server/matchmaker/matchmaker.go +++ b/pkg/server/matchmaker/matchmaker.go @@ -66,6 +66,9 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW func(expiry comms.ExpiryTimeUpdate) { agent.Info.SetExpiryTime(expiry.ExpiryTime) converge.logStatus() + }, + func(heartbeat comms.HeartBeat) { + // Empty }) }() diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index 29787d1..ca0c5f0 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" "log" "net/http" + "os" _ "runtime/pprof" "sync" "time" @@ -30,6 +31,9 @@ func RunAndWait(suite *suite.Suite, functions ...TestFunction) []any { } func StartPprof(port string) *http.Server { + if os.Getenv("PPROF") == "" { + return nil + } if port == "" { port = ":9000" } @@ -49,6 +53,9 @@ func StartPprof(port string) *http.Server { } func StopPprof(ctx context.Context, server *http.Server) { + if os.Getenv("PPROF") == "" { + return + } err := server.Shutdown(ctx) if err != nil { log.Println("Error shutting down test pprof server")