From ba8bd15cf761574a287c0b5399791fd8a05012f0 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Fri, 23 Aug 2024 21:10:01 +0200 Subject: [PATCH] work in progress for testing the matchmaker. --- cmd/converge/converge.go | 4 +- go.mod | 1 + go.sum | 2 + pkg/server/admin/admin_test.go | 42 ++--- pkg/server/matchmaker/matchmaker.go | 32 +++- pkg/server/matchmaker/matchmaker_test.go | 163 ++++++++++++++++++ pkg/testsupport/inmemoryconnection.go | 10 +- .../listener.go} | 26 +-- pkg/testsupport/utils.go | 32 ++++ 9 files changed, 256 insertions(+), 56 deletions(-) create mode 100644 pkg/server/matchmaker/matchmaker_test.go rename pkg/{server/admin/listener_test.go => testsupport/listener.go} (69%) diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go index 8b79f59..1dc41cc 100644 --- a/cmd/converge/converge.go +++ b/cmd/converge/converge.go @@ -186,10 +186,12 @@ func setupWebSockets(admin *matchmaker.MatchMaker, websessions *ui.WebSessions) return } log.Printf("Got registration connection: '%s'\n", publicId) - err = admin.Register(publicId, conn) + waitFunc, err := admin.Register(publicId, conn) if err != nil { log.Printf("Error %v\n", err) + return } + waitFunc() }, } diff --git a/go.mod b/go.mod index 94c8c2f..46d95e1 100755 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/sys v0.22.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 9d0bb92..7c5edba 100755 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go index 74579ab..c4f9d96 100644 --- a/pkg/server/admin/admin_test.go +++ b/pkg/server/admin/admin_test.go @@ -38,11 +38,6 @@ type AdminTestSuite struct { hostKey []byte } -func (s *AdminTestSuite) createPipe() (io.ReadWriteCloser, io.ReadWriteCloser) { - bitpipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10) - return bitpipe.Front(), bitpipe.Back() -} - func (s *AdminTestSuite) SetupSuite() { s.pprofServer = testsupport.StartPprof("") } @@ -77,7 +72,7 @@ type AddAgentResult struct { } func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId string) (AddAgentResult, AgentRegisterResult) { - agentToServerRW, serverToAgentRW := s.createPipe() + agentToServerRW, serverToAgentRW := testsupport.CreatePipe(s.ctx) res := testsupport.RunAndWait( &s.Suite, func() any { @@ -102,7 +97,7 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri type AgentRegisterResult struct { registration comms.AgentRegistration commChannel comms.CommChannel - listener *TestAgentListener + listener *testsupport.TestAgentListener err error } @@ -224,7 +219,8 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() { defer wg.Done() iclient := i client := fmt.Sprintf("client %d/%d", iagent, iclient) - s.connectClientToAgent(client, publicId, data, agentRes) + _, err := s.connectClientToAgent(client, publicId, data, agentRes) + s.Nil(err) }() } }() @@ -235,7 +231,7 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() { func (s *AdminTestSuite) connectClientToAgent( client string, publicId string, data string, agentRes AgentRegisterResult) (*clientConnection, error) { - serverToClientRW, clientToServerRW := s.createPipe() + serverToClientRW, clientToServerRW := testsupport.CreatePipe(s.ctx) // TODO refactoring // - TestAgentListener should run in a separate go routine @@ -295,36 +291,18 @@ func (s *AdminTestSuite) connectClientToAgent( go clientConn.Synchronize() msg := fmt.Sprintf("end-to-end %s", client) // verify bidirectional communication - s.bidirectionalConnectionCheck(msg, clientToServerRW, agentToServerYamux) + testsupport.BidirectionalConnectionCheck(&s.Suite, msg, clientToServerRW, agentToServerYamux) return clientConn, nil } -func (s *AdminTestSuite) bidirectionalConnectionCheck(msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) { - data1 := msg + " client->agent" - data2 := msg + " agent->client" - log.Printf("BIDIRECTIONAL CHECK %v -> %v", msg, agentToServerYamux) - testsupport.RunAndWait( - &s.Suite, - func() any { - testsupport.AssertWriteData(&s.Suite, data1, clientToServerRW) - testsupport.AssertReadData(&s.Suite, data2, clientToServerRW) - return nil - }, - func() any { - testsupport.AssertReadData(&s.Suite, data1, agentToServerYamux) - testsupport.AssertWriteData(&s.Suite, data2, agentToServerYamux) - return nil - }) -} - func (s *AdminTestSuite) Test_connectClientUnknownRendezVousId() { publicId := "abc" serverRes, agentRes := s.agentRegisters(publicId, publicId) s.Nil(serverRes.err) s.Nil(agentRes.err) - serverToClientRW, _ := s.createPipe() + serverToClientRW, _ := testsupport.CreatePipe(s.ctx) _, err := s.admin.AddClient(models.RendezVousId(publicId+"sothatisunknown"), iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr"))) @@ -363,7 +341,7 @@ func (s *AdminTestSuite) agentRegistration(agentToServerRW io.ReadWriteCloser) A s.NotNil(commChannel) baseListener := comms.NewAgentListener(commChannel.Session) - listener := NewTestListener(s.ctx, baseListener) + listener := testsupport.NewTestListener(s.ctx, baseListener) return AgentRegisterResult{ registration: agentRegistration, @@ -381,10 +359,10 @@ func (s *AdminTestSuite) connectClient(publicId string, serverToClientRW io.Read return clientConn } -func (s *AdminTestSuite) clientConnection(clientId models.ClientId, listener *TestAgentListener) (net.Conn, error) { +func (s *AdminTestSuite) clientConnection(clientId models.ClientId, listener *testsupport.TestAgentListener) (net.Conn, error) { // agent log.Printf("clientConnection: Getting connection for %v", clientId) - agentToServerYamux, err := listener.getConnection(clientId) + agentToServerYamux, err := listener.GetConnection(string(clientId)) log.Printf("clientConnection: Got connection %v for client %v", agentToServerYamux, clientId) s.Nil(err) return agentToServerYamux, err diff --git a/pkg/server/matchmaker/matchmaker.go b/pkg/server/matchmaker/matchmaker.go index f1fe1ca..227e401 100644 --- a/pkg/server/matchmaker/matchmaker.go +++ b/pkg/server/matchmaker/matchmaker.go @@ -33,24 +33,35 @@ func NewMatchMaker(notifier Notifier) *MatchMaker { return &converge } -func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadWriteCloser) error { +func (converge *MatchMaker) Close() { + converge.admin.Close() +} + +type WaitForAgentFunc func() + +func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadWriteCloser) (waitForAgentFunc WaitForAgentFunc, err error) { serverInfo := comms.ServerInfo{} agentInfo, err := comms.ServerInitialization(conn, serverInfo) if err != nil { - return err + return nil, err } agent, err := converge.admin.AddAgent(hostPrivateKey, publicId, agentInfo, conn) converge.logStatus() if err != nil { - return err + return nil, err } publicId = agent.Info.PublicId - defer func() { + cleanupFunc := func() { converge.admin.RemoveAgent(publicId) converge.logStatus() + } + defer func() { + if err != nil { + cleanupFunc() + } }() go func() { @@ -72,11 +83,14 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW }) }() - go log.Printf("agentConnection registered: '%s'\n", publicId) - for !agent.CommChannel.Session.IsClosed() { - time.Sleep(250 * time.Millisecond) - } - return nil + return WaitForAgentFunc(func() { + defer cleanupFunc() + log.Printf("agentConnection registered: '%s'\n", publicId) + for !agent.CommChannel.Session.IsClosed() { + time.Sleep(250 * time.Millisecond) + } + log.Printf("Agent disconnected") + }), nil } func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVousId, conn iowrappers2.ReadWriteAddrCloser) error { diff --git a/pkg/server/matchmaker/matchmaker_test.go b/pkg/server/matchmaker/matchmaker_test.go new file mode 100644 index 0000000..c0a7b98 --- /dev/null +++ b/pkg/server/matchmaker/matchmaker_test.go @@ -0,0 +1,163 @@ +package matchmaker + +import ( + "context" + "git.wamblee.org/converge/pkg/comms" + "git.wamblee.org/converge/pkg/models" + "git.wamblee.org/converge/pkg/testsupport" + "github.com/stretchr/testify/suite" + "go.uber.org/goleak" + "io" + "log" + "net/http" + "testing" + "time" +) + +type MatchMakerTestSuite struct { + suite.Suite + + ctx context.Context + cancelFunc context.CancelFunc + pprofServer *http.Server + + notifier *TestNotifier + matchMaker *MatchMaker +} + +type TestNotifier struct { + // last reported state + state *models.State +} + +func (notifier *TestNotifier) Publish(state *models.State) { + notifier.state = state +} + +func (s *MatchMakerTestSuite) SetupSuite() { + s.pprofServer = testsupport.StartPprof("") +} + +func (s *MatchMakerTestSuite) TearDownSuite() { + testsupport.StopPprof(s.ctx, s.pprofServer) +} + +func (s *MatchMakerTestSuite) SetupTest() { + ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second) + s.ctx = ctx + s.cancelFunc = cancelFunc + + s.notifier = &TestNotifier{} + s.matchMaker = NewMatchMaker(s.notifier) +} + +func (s *MatchMakerTestSuite) TearDownTest() { + s.matchMaker.Close() + s.cancelFunc() + goleak.VerifyNone(s.T()) +} + +func TestMatchMakerTestSuite(t *testing.T) { + suite.Run(t, &MatchMakerTestSuite{}) +} + +type TestAgent struct { + agentSideConn io.ReadWriteCloser + serverSIdeConn io.ReadWriteCloser + + agentRegistration comms.AgentRegistration + commChannel comms.CommChannel + listener *testsupport.TestAgentListener +} + +func NewTestAgent(ctx context.Context) *TestAgent { + res := TestAgent{} + a, s := testsupport.CreatePipe(ctx) + res.agentSideConn = a + res.serverSIdeConn = s + return &res +} + +func (agent *TestAgent) Disconnect() { + agent.agentSideConn.Close() +} + +func (agent *TestAgent) Initialize(s *MatchMakerTestSuite) (comms.ServerInfo, error) { + return comms.AgentInitialization(agent.agentSideConn, comms.NewEnvironmentInfo("bash")) +} + +func (agent *TestAgent) Register(s *MatchMakerTestSuite) error { + agentRegistration, err := comms.ReceiveRegistrationMessage(agent.agentSideConn) + if err != nil { + return err + } + agent.agentRegistration = agentRegistration + commChannel, err := comms.NewCommChannel(comms.Agent, agent.agentSideConn) + if err != nil { + return err + } + s.NotNil(commChannel) + agent.commChannel = commChannel + + baseListener := comms.NewAgentListener(commChannel.Session) + agent.listener = testsupport.NewTestListener(s.ctx, baseListener) + return nil +} + +func (s *MatchMakerTestSuite) Test_newMatchMaker() { + s.checkState(0, 0) +} + +func (s *MatchMakerTestSuite) Test_singleAgent() { + publicId := models.RendezVousId("abc") + agent := NewTestAgent(s.ctx) + waitForAgentFunc := s.registerAgent(publicId, agent) + + s.checkState(1, 0) + + // required for connection loss detection + go waitForAgentFunc() + + agent.Disconnect() + s.checkState(0, 0) +} + +func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) { + s.True(testsupport.CheckCondition(s.ctx, func() bool { + return nAgents == len(s.notifier.state.Agents) + })) + s.True(testsupport.CheckCondition(s.ctx, func() bool { + return nClients == len(s.notifier.state.Clients) + })) +} + +func (s *MatchMakerTestSuite) registerAgent(publicId models.RendezVousId, agent *TestAgent) WaitForAgentFunc { + res := testsupport.RunAndWait( + &s.Suite, + func() any { + // ignore waitFunc for now. + waitFunc, err := s.matchMaker.Register(publicId, agent.serverSIdeConn) + s.Nil(err) + log.Printf("MatchMaskerTest: Agent registered by server") + return waitFunc + }, + func() any { + _, err := agent.Initialize(s) + if err != nil { + s.Nil(err) + return nil + } + err = agent.Register(s) + if err != nil { + s.Nil(err) + return nil + } + log.Println("MatchMakerTest: Agent connected to server") + return nil + }) + + if res[0] == nil { + return nil + } + return res[0].(WaitForAgentFunc) +} diff --git a/pkg/testsupport/inmemoryconnection.go b/pkg/testsupport/inmemoryconnection.go index 3673e03..73842d9 100644 --- a/pkg/testsupport/inmemoryconnection.go +++ b/pkg/testsupport/inmemoryconnection.go @@ -1,6 +1,9 @@ package testsupport -import "context" +import ( + "context" + "io" +) type InmemoryConnection struct { ctx context.Context @@ -31,3 +34,8 @@ func (bitpipe *InmemoryConnection) Back() *ChannelReadWriteCloser { func pipe(ctx context.Context, receiveBuffer <-chan []byte, sendBuffer chan<- []byte, remoteAddr string) *ChannelReadWriteCloser { return NewChannelReadWriteCloser(ctx, receiveBuffer, sendBuffer) } + +func CreatePipe(ctx context.Context) (io.ReadWriteCloser, io.ReadWriteCloser) { + bitpipe := NewInmemoryConnection(ctx, "inmemory", 10) + return bitpipe.Front(), bitpipe.Back() +} diff --git a/pkg/server/admin/listener_test.go b/pkg/testsupport/listener.go similarity index 69% rename from pkg/server/admin/listener_test.go rename to pkg/testsupport/listener.go index 65eb846..a87c578 100644 --- a/pkg/server/admin/listener_test.go +++ b/pkg/testsupport/listener.go @@ -1,16 +1,16 @@ -package admin +package testsupport import ( "context" "errors" - "git.wamblee.org/converge/pkg/models" "log" "net" "sync" ) // Extension of agentlistener for testing. It can accept all connections and puts them into a map based -// on clientId after which a client can retrieve the accepted connection based on client id. +// on clientId after which a client can retrieve the accepted connection based on local address (which is the +// client id. type TestAgentListener struct { net.Listener @@ -18,7 +18,7 @@ type TestAgentListener struct { ctx context.Context mutex sync.Mutex cond *sync.Cond - connections map[models.ClientId]net.Conn + connections map[string]net.Conn } func NewTestListener(ctx context.Context, listener net.Listener) *TestAgentListener { @@ -26,7 +26,7 @@ func NewTestListener(ctx context.Context, listener net.Listener) *TestAgentListe ctx: ctx, Listener: listener, mutex: sync.Mutex{}, - connections: make(map[models.ClientId]net.Conn), + connections: make(map[string]net.Conn), } res.cond = sync.NewCond(&res.mutex) @@ -56,23 +56,23 @@ func (l *TestAgentListener) Accept() (net.Conn, error) { if err != nil { return nil, err } - clientId := models.ClientId(conn.LocalAddr().String()) - log.Printf("testlistener: Storing connection %v %v", clientId, conn) + localAddr := conn.LocalAddr().String() + log.Printf("testlistener: Storing connection %v %v", localAddr, conn) l.mutex.Lock() defer l.mutex.Unlock() - l.connections[clientId] = conn - log.Printf("testlistener: broadcasting %v", clientId) + l.connections[localAddr] = conn + log.Printf("testlistener: broadcasting %v", localAddr) l.cond.Broadcast() return conn, err } -func (l *TestAgentListener) getConnection(clientId models.ClientId) (net.Conn, error) { +func (l *TestAgentListener) GetConnection(localAddr string) (net.Conn, error) { l.mutex.Lock() defer l.mutex.Unlock() // We need to check the condition before the first cond.wait as well. Otherwise, a broadcast sent // at this point in time will not be caught, and if there are no further broadcasts happening, then // the code will hang her. - for ok := l.connections[clientId] != nil; !ok; ok = l.connections[clientId] != nil { + for ok := l.connections[localAddr] != nil; !ok; ok = l.connections[localAddr] != nil { log.Println("Listener cond wait") l.cond.Wait() log.Println("Listener awoken") @@ -82,6 +82,6 @@ func (l *TestAgentListener) getConnection(clientId models.ClientId) (net.Conn, e default: } } - log.Printf("Returning connection %v %v", clientId, l.connections[clientId]) - return l.connections[clientId], nil + log.Printf("Returning connection %v %v", localAddr, l.connections[localAddr]) + return l.connections[localAddr], nil } diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index c7ec120..00edd23 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -96,3 +96,35 @@ func PrintStackTraces() { log.Println(string(buf)) log.Println("") } + +func BidirectionalConnectionCheck(s *suite.Suite, msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) { + data1 := msg + " -> " + data2 := msg + " <- " + log.Printf("BIDIRECTIONAL CHECK %v", msg) + RunAndWait( + s, + func() any { + AssertWriteData(s, data1, clientToServerRW) + AssertReadData(s, data2, clientToServerRW) + return nil + }, + func() any { + AssertReadData(s, data1, agentToServerYamux) + AssertWriteData(s, data2, agentToServerYamux) + return nil + }) +} + +// having the return type bool forces the check to be done in the test code +// leading to more clear error messages. +func CheckCondition(ctx context.Context, condition func() bool) bool { + for !condition() { + select { + case <-ctx.Done(): + return false + default: + time.Sleep(1 * time.Millisecond) + } + } + return true +}