From 2accbaf3fe00aec50f6c7b607601cef808725d77 Mon Sep 17 00:00:00 2001
From: Erik Brakkee <erik@brakkee.org>
Date: Sat, 24 Aug 2024 19:48:53 +0200
Subject: [PATCH] changes for testing agent connect by returning a
 synchronizer() function that must be called after connect to synchronize
 data.

---
 Makefile                                 |  2 +-
 cmd/converge/converge.go                 |  4 +-
 pkg/server/matchmaker/matchmaker.go      | 42 +++++++++++-----
 pkg/server/matchmaker/matchmaker_test.go | 63 ++++++++++++++++++++++++
 4 files changed, 96 insertions(+), 15 deletions(-)

diff --git a/Makefile b/Makefile
index e964ae8..82d42a9 100644
--- a/Makefile
+++ b/Makefile
@@ -14,7 +14,7 @@ vet: fmt
 	go vet ./...
 
 test: build
-	go test -count=1  -v ./...
+	go test -count=1 -v ./...
 
 build: generate vet
 	mkdir -p bin
diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go
index 1dc41cc..8c39649 100644
--- a/cmd/converge/converge.go
+++ b/cmd/converge/converge.go
@@ -205,10 +205,12 @@ func setupWebSockets(admin *matchmaker.MatchMaker, websessions *ui.WebSessions)
 			}
 			_, wsProxyMode := r.URL.Query()["wsproxy"]
 			log.Printf("Got client connection: '%s'\n", publicId)
-			err = admin.Connect(wsProxyMode, publicId, conn)
+			_, synchronizer, err := admin.Connect(wsProxyMode, publicId, conn)
 			if err != nil {
 				log.Printf("Error %v\n", err)
+				return
 			}
+			synchronizer()
 		},
 	}
 
diff --git a/pkg/server/matchmaker/matchmaker.go b/pkg/server/matchmaker/matchmaker.go
index 227e401..5a2e04a 100644
--- a/pkg/server/matchmaker/matchmaker.go
+++ b/pkg/server/matchmaker/matchmaker.go
@@ -55,7 +55,7 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW
 	}
 	publicId = agent.Info.PublicId
 	cleanupFunc := func() {
-		converge.admin.RemoveAgent(publicId)
+		_ = converge.admin.RemoveAgent(publicId)
 		converge.logStatus()
 	}
 	defer func() {
@@ -93,9 +93,16 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW
 	}), nil
 }
 
-func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVousId, conn iowrappers2.ReadWriteAddrCloser) error {
-	defer conn.Close()
+type SynchronizeStreamsFunc func()
 
+func (converge *MatchMaker) Connect(wsProxyMode bool,
+	publicId models.RendezVousId, conn iowrappers2.ReadWriteAddrCloser) (clientId models.ClientId, synchronizer SynchronizeStreamsFunc, err error) {
+
+	defer func() {
+		if err != nil {
+			conn.Close()
+		}
+	}()
 	log.Printf("Using wsproxy protocol %v", wsProxyMode)
 	channel := comms.NewGOBChannel(conn)
 	if wsProxyMode {
@@ -106,11 +113,20 @@ func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVous
 			})
 		if err != nil {
 			log.Printf("Error sending protocol version to client %v", err)
-			return err
+			return "", nil, err
 		}
 	}
 
 	client, err := converge.admin.AddClient(publicId, conn)
+	cleanUpFunc := func() {
+		converge.admin.RemoveClient(client)
+		converge.logStatus()
+	}
+	defer func() {
+		if err != nil {
+			cleanUpFunc()
+		}
+	}()
 	if err != nil {
 		if wsProxyMode {
 			_ = comms.SendWithTimeout(channel,
@@ -119,12 +135,8 @@ func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVous
 					Message: err.Error(),
 				})
 		}
-		return err
+		return "", nil, err
 	}
-	defer func() {
-		converge.admin.RemoveClient(client)
-		converge.logStatus()
-	}()
 	log.Printf("Connecting client and agent: '%s'\n", publicId)
 	if wsProxyMode {
 		err = comms.SendWithTimeout(channel,
@@ -133,17 +145,21 @@ func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVous
 				Message: "Connecting to agent",
 			})
 		if err != nil {
-			return fmt.Errorf("Error sending connection info to client: %v", err)
+			return "", nil, fmt.Errorf("Error sending connection info to client: %v", err)
 		}
 		clientEnvironment, err := comms.ReceiveWithTimeout[comms.EnvironmentInfo](channel)
 		if err != nil {
-			return fmt.Errorf("Error receiving environment info from client: %v", err)
+			return "", nil, fmt.Errorf("Error receiving environment info from client: %v", err)
 		}
 		client.Info.EnvironmentInfo = clientEnvironment
 	}
 	converge.logStatus()
-	client.Synchronize()
-	return nil
+
+	return client.Info.ClientId, SynchronizeStreamsFunc(func() {
+		defer conn.Close()
+		defer cleanUpFunc()
+		client.Synchronize()
+	}), nil
 }
 
 func (converge *MatchMaker) logStatus() {
diff --git a/pkg/server/matchmaker/matchmaker_test.go b/pkg/server/matchmaker/matchmaker_test.go
index c0a7b98..80275c6 100644
--- a/pkg/server/matchmaker/matchmaker_test.go
+++ b/pkg/server/matchmaker/matchmaker_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"git.wamblee.org/converge/pkg/comms"
 	"git.wamblee.org/converge/pkg/models"
+	"git.wamblee.org/converge/pkg/support/iowrappers"
 	"git.wamblee.org/converge/pkg/testsupport"
 	"github.com/stretchr/testify/suite"
 	"go.uber.org/goleak"
@@ -104,6 +105,25 @@ func (agent *TestAgent) Register(s *MatchMakerTestSuite) error {
 	return nil
 }
 
+type TestClient struct {
+	clientSideConn io.ReadWriteCloser
+	serverSIdeConn iowrappers.ReadWriteAddrCloser
+}
+
+func NewTestClient(ctx context.Context) *TestClient {
+	a, b := testsupport.CreatePipe(ctx)
+	res := TestClient{
+		clientSideConn: a,
+		serverSIdeConn: iowrappers.NewSimpleReadWriteAddrCloser(b,
+			testsupport.DummyRemoteAddr("remoteaddr")),
+	}
+	return &res
+}
+
+func (c *TestClient) Disconnect() {
+	c.clientSideConn.Close()
+}
+
 func (s *MatchMakerTestSuite) Test_newMatchMaker() {
 	s.checkState(0, 0)
 }
@@ -122,6 +142,49 @@ func (s *MatchMakerTestSuite) Test_singleAgent() {
 	s.checkState(0, 0)
 }
 
+func (s *MatchMakerTestSuite) Test_singleAgentAndClient() {
+	publicId := models.RendezVousId("abc")
+	agent := NewTestAgent(s.ctx)
+
+	waitForAgentFunc := s.registerAgent(publicId, agent)
+	go waitForAgentFunc()
+
+	client := NewTestClient(s.ctx)
+	var clientId models.ClientId
+	testsupport.RunAndWait(
+		&s.Suite,
+		func() any {
+			//server
+			clientIdCreated, synchronizer, err := s.matchMaker.Connect(false, publicId, client.serverSIdeConn)
+			clientId = clientIdCreated
+			s.Nil(err)
+			if err == nil {
+				log.Println("test: synchronizing streams.")
+				go synchronizer()
+			}
+			return nil
+		},
+		func() any {
+			// client, nothing to do with wsproxy mode off.
+			return nil
+		})
+
+	s.checkState(1, 1)
+
+	agentClientSideConn, err := agent.listener.GetConnection(string(clientId))
+	log.Printf("Agent side conn %v", agentClientSideConn)
+	s.Nil(err)
+	testsupport.BidirectionalConnectionCheck(
+		&s.Suite, "testmsg",
+		client.clientSideConn,
+		agentClientSideConn)
+
+	client.Disconnect()
+	// It is the agents choice to exit> The test agent does not exit by default when
+	// there are no more connections.
+	s.checkState(1, 0)
+}
+
 func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) {
 	s.True(testsupport.CheckCondition(s.ctx, func() bool {
 		return nAgents == len(s.notifier.state.Agents)