From 670a7053264ea7e1a795c6d7f798b16cc7afaf0e Mon Sep 17 00:00:00 2001
From: Erik Brakkee <erik@brakkee.org>
Date: Fri, 23 Aug 2024 16:32:41 +0200
Subject: [PATCH] test multiple clients now working with a fully concurrent
 registration of clients.

---
 pkg/comms/agentlistener.go        |   2 -
 pkg/server/admin/admin_test.go    | 160 ++++++++++++++++++++----------
 pkg/server/admin/listener_test.go |  87 ++++++++++++++++
 pkg/testsupport/utils.go          |  12 ++-
 4 files changed, 207 insertions(+), 54 deletions(-)
 create mode 100644 pkg/server/admin/listener_test.go

diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go
index 0ddaabc..71f7db8 100644
--- a/pkg/comms/agentlistener.go
+++ b/pkg/comms/agentlistener.go
@@ -2,7 +2,6 @@ package comms
 
 import (
 	"git.wamblee.org/converge/pkg/support/websocketutil"
-	"log"
 	"net"
 )
 
@@ -43,7 +42,6 @@ func (listener AgentListener) Accept() (net.Conn, error) {
 		return nil, err
 	}
 	conn = NewLocalAddrHackConn(conn, clientId)
-	log.Printf("ACCEPT %v %v", clientId, conn)
 	return conn, nil
 }
 
diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go
index e9d1ed0..74579ab 100644
--- a/pkg/server/admin/admin_test.go
+++ b/pkg/server/admin/admin_test.go
@@ -3,14 +3,17 @@ package admin
 import (
 	"context"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"git.wamblee.org/converge/pkg/comms"
 	"git.wamblee.org/converge/pkg/models"
 	"git.wamblee.org/converge/pkg/support/iowrappers"
 	"git.wamblee.org/converge/pkg/testsupport"
 	"github.com/stretchr/testify/suite"
+	"go.uber.org/goleak"
 	"io"
 	"log"
+	"net"
 	"net/http"
 	"strings"
 	"sync"
@@ -61,7 +64,7 @@ func (s *AdminTestSuite) SetupTest() {
 func (s *AdminTestSuite) TearDownTest() {
 	s.admin.Close()
 	s.cancelFunc()
-	//goleak.VerifyNone(s.T())
+	goleak.VerifyNone(s.T())
 }
 
 func TestAdminTestSuite(t *testing.T) {
@@ -73,12 +76,6 @@ type AddAgentResult struct {
 	err       error
 }
 
-type AgentRegisterResult struct {
-	registration comms.AgentRegistration
-	commChannel  comms.CommChannel
-	err          error
-}
-
 func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId string) (AddAgentResult, AgentRegisterResult) {
 	agentToServerRW, serverToAgentRW := s.createPipe()
 	res := testsupport.RunAndWait(
@@ -102,6 +99,13 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri
 	return res[0].(AddAgentResult), res[1].(AgentRegisterResult)
 }
 
+type AgentRegisterResult struct {
+	registration comms.AgentRegistration
+	commChannel  comms.CommChannel
+	listener     *TestAgentListener
+	err          error
+}
+
 func (s *AdminTestSuite) Test_AgentRegisters() {
 	publicId := "abc"
 	res, _ := s.agentRegisters(publicId, publicId)
@@ -162,29 +166,18 @@ func (s *AdminTestSuite) Test_agentDuplicateId() {
 	s.False(agentSideResult.registration.Ok)
 }
 
-func (s *AdminTestSuite) Test_connectClient() {
+func (s *AdminTestSuite) Test_connectClient() error {
 	publicId := "abc"
 	serverRes, agentRes := s.agentRegisters(publicId, publicId)
 	s.Nil(serverRes.err)
 	s.Nil(agentRes.err)
 
-	serverToClientRW, clientToServerRW := s.createPipe()
-
 	data := "connect client test msg"
-	res := testsupport.RunAndWait(
-		&s.Suite,
-		func() any {
-			return s.connectClient(publicId, serverToClientRW, data)
-		},
-		func() any {
-			return s.clientConnection("0", agentRes, data)
-		})
-
-	// bidirectional communication check
-	clientConn := res[0].(*clientConnection)
-	agentToServerYamux := res[1].(io.ReadWriter)
-	go clientConn.Synchronize()
-	s.bidirectionalConnectionCheck("mymessage", clientToServerRW, agentToServerYamux)
+	clientConn, err := s.connectClientToAgent("singleclient", publicId, data, agentRes)
+	s.Nil(err)
+	if err != nil {
+		return err
+	}
 
 	// verify state
 	state := s.admin.CreateNotifification()
@@ -194,18 +187,20 @@ func (s *AdminTestSuite) Test_connectClient() {
 
 	// removing the client will close all connections, we test this by writing to the connections
 	// after removing the client.
-	s.admin.RemoveClient(clientConn)
+	err = s.admin.RemoveClient(clientConn)
+	s.Nil(err)
 	buf := make([]byte, 10)
-	_, err := clientConn.clientConnection.Write(buf)
+	_, err = clientConn.clientConnection.Write(buf)
 	s.NotNil(err)
 	s.True(strings.Contains(err.Error(), "closed"))
 	_, err = clientConn.agentConnection.Write(buf)
 	s.NotNil(err)
 	s.True(strings.Contains(err.Error(), "closed"))
+	return nil
 }
 
 func (s *AdminTestSuite) Test_MultipleAgentsAndClients() {
-	clientCounts := []int{23, 5, 3, 1}
+	clientCounts := []int{10, 5, 37, 1, 29}
 
 	wg := sync.WaitGroup{}
 	for iagent, clientCount := range clientCounts {
@@ -224,9 +219,13 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() {
 				// created in a map base on client id. The client can then retrieve the
 				// connection based on the client id and should also wait until the
 				// connection is available.
-				iclient := i
-				client := fmt.Sprintf("client %d/%d", iagent, iclient)
-				s.connectClientToAgent(client, publicId, data, agentRes)
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					iclient := i
+					client := fmt.Sprintf("client %d/%d", iagent, iclient)
+					s.connectClientToAgent(client, publicId, data, agentRes)
+				}()
 			}
 		}()
 	}
@@ -234,23 +233,71 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() {
 
 }
 
-func (s *AdminTestSuite) connectClientToAgent(client string, publicId string, data string, agentRes AgentRegisterResult) {
+func (s *AdminTestSuite) connectClientToAgent(
+	client string, publicId string, data string, agentRes AgentRegisterResult) (*clientConnection, error) {
 	serverToClientRW, clientToServerRW := s.createPipe()
+
+	// TODO refactoring
+	// - TestAgentListener should run in a separate go routine
+	//   Started by TestAgentSuite.
+	//
+	// TODO split up:
+	// 1. server: connects to agent, agent: listens for connections
+	//    output:   server: clientConnection with, among others, the clientId
+	//              agent: listener
+	// 2. communication check:
+	//    server: use yamux connection to send message
+	//    agent: retrieve connection from listener based on client id from clientConnection
+	//           -> yamux connection
+	//    exchange messages in two directions.
+	// 3. bidirectional communication
+	//    full communication from client to agent through the converge server.
+
+	// Connect server to agent
 	res := testsupport.RunAndWait(
 		&s.Suite,
+		// Server: agent is already listening and accepts all connections and stores them based on clientId
 		func() any {
-			return s.connectClient(publicId, serverToClientRW, data)
-		},
-		func() any {
-			return s.clientConnection(client, agentRes, data)
+			return s.connectClient(publicId, serverToClientRW)
 		})
 
 	// bidirectional communication check
 	clientConn := res[0].(*clientConnection)
-	agentToServerYamux := res[1].(io.ReadWriter)
+	s.NotNil(clientConn)
+	if clientConn == nil {
+		return nil, errors.New("Client connection is nil")
+	}
+	clientId := clientConn.Info.ClientId
+
+	// Retrieve the agent side connection for this client that was set up by the server
+	agentToServerYamux, err := s.clientConnection(clientId, agentRes.listener)
+	s.Nil(err)
+	if err != nil {
+		return nil, err
+	}
+
+	log.Println("Got agentToServerYamux")
+	serverToAgentYamux := clientConn.agentConnection
+
+	// Now first test the communication from server to agent over the just established connection
+	testsupport.RunAndWait(
+		&s.Suite,
+		func() any {
+			s.sendYamuxMsgServerToAgent(serverToAgentYamux, data)
+			return nil
+		},
+		func() any {
+			s.receiveYamuxMsgServerToAgent(agentToServerYamux, data)
+			return nil
+		})
+
+	// Synchronize data between client and agent through the server
 	go clientConn.Synchronize()
 	msg := fmt.Sprintf("end-to-end %s", client)
+	// verify bidirectional communication
 	s.bidirectionalConnectionCheck(msg, clientToServerRW, agentToServerYamux)
+
+	return clientConn, nil
 }
 
 func (s *AdminTestSuite) bidirectionalConnectionCheck(msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) {
@@ -314,30 +361,41 @@ func (s *AdminTestSuite) agentRegistration(agentToServerRW io.ReadWriteCloser) A
 		return AgentRegisterResult{registration: agentRegistration, err: err}
 	}
 	s.NotNil(commChannel)
-	return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil}
+
+	baseListener := comms.NewAgentListener(commChannel.Session)
+	listener := NewTestListener(s.ctx, baseListener)
+
+	return AgentRegisterResult{
+		registration: agentRegistration,
+		commChannel:  commChannel,
+		listener:     listener,
+		err:          nil,
+	}
 }
 
-func (s *AdminTestSuite) connectClient(publicId string, serverToClientRW io.ReadWriteCloser, data string) any {
+func (s *AdminTestSuite) connectClient(publicId string, serverToClientRW io.ReadWriteCloser) any {
 	// server
 	clientConn, err := s.admin.AddClient(models.RendezVousId(publicId),
 		iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr")))
 	s.Nil(err)
-	// Connection to agent over yamux
-	serverToAgentYamux := clientConn.agentConnection
-	// test by sending a message to the agent.
-	testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux)
 	return clientConn
 }
 
-func (s *AdminTestSuite) clientConnection(client string, agentRes AgentRegisterResult, data string) any {
+func (s *AdminTestSuite) clientConnection(clientId models.ClientId, listener *TestAgentListener) (net.Conn, error) {
 	// agent
-	listener := comms.NewAgentListener(agentRes.commChannel.Session)
-	//.Connection from server over yamux
-	agentToServerYamux, err := listener.Accept()
+	log.Printf("clientConnection: Getting connection for %v", clientId)
+	agentToServerYamux, err := listener.getConnection(clientId)
+	log.Printf("clientConnection: Got connection %v for client %v", agentToServerYamux, clientId)
 	s.Nil(err)
-	log.Printf("RESULT FROM ACCEPT %s %v", client, agentToServerYamux)
-	// Test by receiving a message from the server
-	testsupport.AssertReadData(&s.Suite, data, agentToServerYamux)
-	log.Printf("Asserted on read data: %v", data)
-	return agentToServerYamux
+	return agentToServerYamux, err
+}
+
+func (s *AdminTestSuite) sendYamuxMsgServerToAgent(serverToAgentYamux io.Writer, data string) {
+	// server
+	testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux)
+}
+
+func (s *AdminTestSuite) receiveYamuxMsgServerToAgent(agentToServerYamux io.Reader, data string) {
+	// agent
+	testsupport.AssertReadData(&s.Suite, data, agentToServerYamux)
 }
diff --git a/pkg/server/admin/listener_test.go b/pkg/server/admin/listener_test.go
new file mode 100644
index 0000000..65eb846
--- /dev/null
+++ b/pkg/server/admin/listener_test.go
@@ -0,0 +1,87 @@
+package admin
+
+import (
+	"context"
+	"errors"
+	"git.wamblee.org/converge/pkg/models"
+	"log"
+	"net"
+	"sync"
+)
+
+// Extension of agentlistener for testing. It accepts all connections and puts them into a map based
+// on clientId, after which a client can retrieve the accepted connection based on client id.
+
+type TestAgentListener struct {
+	net.Listener
+
+	ctx         context.Context
+	mutex       sync.Mutex
+	cond        *sync.Cond
+	connections map[models.ClientId]net.Conn
+}
+
+func NewTestListener(ctx context.Context, listener net.Listener) *TestAgentListener {
+	res := &TestAgentListener{
+		ctx:         ctx,
+		Listener:    listener,
+		mutex:       sync.Mutex{},
+		connections: make(map[models.ClientId]net.Conn),
+	}
+	res.cond = sync.NewCond(&res.mutex)
+
+	go func() {
+		for {
+			conn, err := res.Accept()
+			log.Printf("testlistener: Got connection %v %v", conn, err)
+			if err != nil {
+				return
+			}
+		}
+	}()
+
+	go func() {
+		select {
+		case <-res.ctx.Done():
+			res.mutex.Lock()
+			res.cond.Broadcast()
+			res.mutex.Unlock()
+		}
+	}()
+	return res
+}
+
+func (l *TestAgentListener) Accept() (net.Conn, error) {
+	conn, err := l.Listener.Accept()
+	if err != nil {
+		return nil, err
+	}
+	clientId := models.ClientId(conn.LocalAddr().String())
+	log.Printf("testlistener: Storing connection %v %v", clientId, conn)
+	l.mutex.Lock()
+	defer l.mutex.Unlock()
+	l.connections[clientId] = conn
+	log.Printf("testlistener: broadcasting %v", clientId)
+	l.cond.Broadcast()
+	return conn, err
+}
+
+func (l *TestAgentListener) getConnection(clientId models.ClientId) (net.Conn, error) {
+	l.mutex.Lock()
+	defer l.mutex.Unlock()
+	// We need to check the condition before the first cond.Wait as well. Otherwise, a broadcast sent
+	// at this point in time will not be caught, and if there are no further broadcasts happening, then
+	// the code will hang here.
+	for ok := l.connections[clientId] != nil; !ok; ok = l.connections[clientId] != nil {
+		log.Println("Listener cond wait")
+		l.cond.Wait()
+		log.Println("Listener awoken")
+		select {
+		case <-l.ctx.Done():
+			return nil, errors.New("Listenere terminated because context canceled")
+		default:
+		}
+	}
+	log.Printf("Returning connection %v %v", clientId, l.connections[clientId])
+	return l.connections[clientId], nil
+}
diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go
index a50d445..c7ec120 100644
--- a/pkg/testsupport/utils.go
+++ b/pkg/testsupport/utils.go
@@ -8,6 +8,7 @@ import (
 	"log"
 	"net/http"
 	"os"
+	"runtime"
 	_ "runtime/pprof"
 	"sync"
 	"time"
@@ -80,9 +81,18 @@ func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
 }
 
 func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
-	buf := make([]byte, len(data)*2)
+	buf := make([]byte, len(data)+1024)
 	n, err := reader.Read(buf)
 	s.Nil(err)
 	s.Equal(len(data), n)
 	s.Equal(data, string(buf[:n]))
 }
+
+func PrintStackTraces() {
+	buf := make([]byte, 100000)
+	runtime.Stack(buf, true)
+	log.Println("STACKTRACE")
+	log.Println("")
+	log.Println(string(buf))
+	log.Println("")
+}