diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go index 0ddaabc..71f7db8 100644 --- a/pkg/comms/agentlistener.go +++ b/pkg/comms/agentlistener.go @@ -2,7 +2,6 @@ package comms import ( "git.wamblee.org/converge/pkg/support/websocketutil" - "log" "net" ) @@ -43,7 +42,6 @@ func (listener AgentListener) Accept() (net.Conn, error) { return nil, err } conn = NewLocalAddrHackConn(conn, clientId) - log.Printf("ACCEPT %v %v", clientId, conn) return conn, nil } diff --git a/pkg/server/admin/admin_test.go b/pkg/server/admin/admin_test.go index e9d1ed0..74579ab 100644 --- a/pkg/server/admin/admin_test.go +++ b/pkg/server/admin/admin_test.go @@ -3,14 +3,17 @@ package admin import ( "context" "crypto/rand" + "errors" "fmt" "git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/models" "git.wamblee.org/converge/pkg/support/iowrappers" "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" + "go.uber.org/goleak" "io" "log" + "net" "net/http" "strings" "sync" @@ -61,7 +64,7 @@ func (s *AdminTestSuite) SetupTest() { func (s *AdminTestSuite) TearDownTest() { s.admin.Close() s.cancelFunc() - //goleak.VerifyNone(s.T()) + goleak.VerifyNone(s.T()) } func TestAdminTestSuite(t *testing.T) { @@ -73,12 +76,6 @@ type AddAgentResult struct { err error } -type AgentRegisterResult struct { - registration comms.AgentRegistration - commChannel comms.CommChannel - err error -} - func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId string) (AddAgentResult, AgentRegisterResult) { agentToServerRW, serverToAgentRW := s.createPipe() res := testsupport.RunAndWait( @@ -102,6 +99,13 @@ func (s *AdminTestSuite) agentRegisters(requestedPublicId, assignedPublicId stri return res[0].(AddAgentResult), res[1].(AgentRegisterResult) } +type AgentRegisterResult struct { + registration comms.AgentRegistration + commChannel comms.CommChannel + listener *TestAgentListener + err error +} + func (s *AdminTestSuite) Test_AgentRegisters() { publicId := "abc" res, _ := s.agentRegisters(publicId, publicId) @@ -162,29 +166,18 @@ func (s *AdminTestSuite) Test_agentDuplicateId() { s.False(agentSideResult.registration.Ok) } -func (s *AdminTestSuite) Test_connectClient() { +func (s *AdminTestSuite) Test_connectClient() error { publicId := "abc" serverRes, agentRes := s.agentRegisters(publicId, publicId) s.Nil(serverRes.err) s.Nil(agentRes.err) - serverToClientRW, clientToServerRW := s.createPipe() - data := "connect client test msg" - res := testsupport.RunAndWait( - &s.Suite, - func() any { - return s.connectClient(publicId, serverToClientRW, data) - }, - func() any { - return s.clientConnection("0", agentRes, data) - }) - - // bidirectional communication check - clientConn := res[0].(*clientConnection) - agentToServerYamux := res[1].(io.ReadWriter) - go clientConn.Synchronize() - s.bidirectionalConnectionCheck("mymessage", clientToServerRW, agentToServerYamux) + clientConn, err := s.connectClientToAgent("singleclient", publicId, data, agentRes) + s.Nil(err) + if err != nil { + return err + } // verify state state := s.admin.CreateNotifification() @@ -194,18 +187,20 @@ func (s *AdminTestSuite) Test_connectClient() { // removing the client will close all connections, we test this by writing to the connections // after removing the client. - s.admin.RemoveClient(clientConn) + err = s.admin.RemoveClient(clientConn) + s.Nil(err) buf := make([]byte, 10) - _, err := clientConn.clientConnection.Write(buf) + _, err = clientConn.clientConnection.Write(buf) s.NotNil(err) s.True(strings.Contains(err.Error(), "closed")) _, err = clientConn.agentConnection.Write(buf) s.NotNil(err) s.True(strings.Contains(err.Error(), "closed")) + return nil } func (s *AdminTestSuite) Test_MultipleAgentsAndClients() { - clientCounts := []int{23, 5, 3, 1} + clientCounts := []int{10, 5, 37, 1, 29} wg := sync.WaitGroup{} for iagent, clientCount := range clientCounts { @@ -224,9 +219,13 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() { // created in a map base on client id. The client can then retrieve the // connection based on the client id and should also wait until the // connection is available. - iclient := i - client := fmt.Sprintf("client %d/%d", iagent, iclient) - s.connectClientToAgent(client, publicId, data, agentRes) + wg.Add(1) + go func() { + defer wg.Done() + iclient := i + client := fmt.Sprintf("client %d/%d", iagent, iclient) + s.connectClientToAgent(client, publicId, data, agentRes) + }() } }() } @@ -234,23 +233,71 @@ func (s *AdminTestSuite) Test_MultipleAgentsAndClients() { } -func (s *AdminTestSuite) connectClientToAgent(client string, publicId string, data string, agentRes AgentRegisterResult) { +func (s *AdminTestSuite) connectClientToAgent( + client string, publicId string, data string, agentRes AgentRegisterResult) (*clientConnection, error) { serverToClientRW, clientToServerRW := s.createPipe() + + // TODO refactoring + // - TestAgentListener should run in a separate go routine + // Started by TestAgentSuite. + // + // TODO split up: + // 1. server: connects to agent, agent: listens for connections + // output: server: clientConnection with a.o. clientId + // agent: listener + // 2. communication check: + // server: use yamux connection to send message + // agent: retrieve connection from listener based on client id from clientConnection + // -> yamux connection + // exchange messages in two directions. + // 3. birectional communication + // full communication from client to agent through the converge server. + + // Connect server to agent res := testsupport.RunAndWait( &s.Suite, + // Server: agent is already listening and accepts all connections and stores them based on clientId func() any { - return s.connectClient(publicId, serverToClientRW, data) - }, - func() any { - return s.clientConnection(client, agentRes, data) + return s.connectClient(publicId, serverToClientRW) }) // bidirectional communication check clientConn := res[0].(*clientConnection) - agentToServerYamux := res[1].(io.ReadWriter) + s.NotNil(clientConn) + if clientConn == nil { + return nil, errors.New("Client connection is nil") + } + clientId := clientConn.Info.ClientId + + // Retrieve the agent side connection for this client that was setup by the server + agentToServerYamux, err := s.clientConnection(clientId, agentRes.listener) + s.Nil(err) + if err != nil { + return nil, err + } + + log.Println("Got agentToServerYamux") + serverToAgentYamux := clientConn.agentConnection + + // Now first test the communication from server to agent over the just established connection + testsupport.RunAndWait( + &s.Suite, + func() any { + s.sendYamuxMsgServerToAgent(serverToAgentYamux, data) + return nil + }, + func() any { + s.receiveYamuxMsgServerToAgent(agentToServerYamux, data) + return nil + }) + + // Synchronize data between client and agent through the server go clientConn.Synchronize() msg := fmt.Sprintf("end-to-end %s", client) + // verify bidirectional communication s.bidirectionalConnectionCheck(msg, clientToServerRW, agentToServerYamux) + + return clientConn, nil } func (s *AdminTestSuite) bidirectionalConnectionCheck(msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) { @@ -314,30 +361,41 @@ func (s *AdminTestSuite) agentRegistration(agentToServerRW io.ReadWriteCloser) A return AgentRegisterResult{registration: agentRegistration, err: err} } s.NotNil(commChannel) - return AgentRegisterResult{registration: agentRegistration, commChannel: commChannel, err: nil} + + baseListener := comms.NewAgentListener(commChannel.Session) + listener := NewTestListener(s.ctx, baseListener) + + return AgentRegisterResult{ + registration: agentRegistration, + commChannel: commChannel, + listener: listener, + err: nil, + } } -func (s *AdminTestSuite) connectClient(publicId string, serverToClientRW io.ReadWriteCloser, data string) any { +func (s *AdminTestSuite) connectClient(publicId string, serverToClientRW io.ReadWriteCloser) any { // server clientConn, err := s.admin.AddClient(models.RendezVousId(publicId), iowrappers.NewSimpleReadWriteAddrCloser(serverToClientRW, testsupport.DummyRemoteAddr("remoteaddr"))) s.Nil(err) - // Connection to agent over yamux - serverToAgentYamux := clientConn.agentConnection - // test by sending a message to the agent. - testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux) return clientConn } -func (s *AdminTestSuite) clientConnection(client string, agentRes AgentRegisterResult, data string) any { +func (s *AdminTestSuite) clientConnection(clientId models.ClientId, listener *TestAgentListener) (net.Conn, error) { // agent - listener := comms.NewAgentListener(agentRes.commChannel.Session) - //.Connection from server over yamux - agentToServerYamux, err := listener.Accept() + log.Printf("clientConnection: Getting connection for %v", clientId) + agentToServerYamux, err := listener.getConnection(clientId) + log.Printf("clientConnection: Got connection %v for client %v", agentToServerYamux, clientId) s.Nil(err) - log.Printf("RESULT FROM ACCEPT %s %v", client, agentToServerYamux) - // Test by receiving a message from the server - testsupport.AssertReadData(&s.Suite, data, agentToServerYamux) - log.Printf("Asserted on read data: %v", data) - return agentToServerYamux + return agentToServerYamux, err +} + +func (s *AdminTestSuite) sendYamuxMsgServerToAgent(serverToAgentYamux io.Writer, data string) { + // server + testsupport.AssertWriteData(&s.Suite, data, serverToAgentYamux) +} + +func (s *AdminTestSuite) receiveYamuxMsgServerToAgent(agentToServerYamux io.Reader, data string) { + // agent + testsupport.AssertReadData(&s.Suite, data, agentToServerYamux) } diff --git a/pkg/server/admin/listener_test.go b/pkg/server/admin/listener_test.go new file mode 100644 index 0000000..65eb846 --- /dev/null +++ b/pkg/server/admin/listener_test.go @@ -0,0 +1,87 @@ +package admin + +import ( + "context" + "errors" + "git.wamblee.org/converge/pkg/models" + "log" + "net" + "sync" +) + +// Extension of agentlistener for testing. It can accept all connections and puts them into a map based +// on clientId after which a client can retrieve the accepted connection based on client id. + +type TestAgentListener struct { + net.Listener + + ctx context.Context + mutex sync.Mutex + cond *sync.Cond + connections map[models.ClientId]net.Conn +} + +func NewTestListener(ctx context.Context, listener net.Listener) *TestAgentListener { + res := &TestAgentListener{ + ctx: ctx, + Listener: listener, + mutex: sync.Mutex{}, + connections: make(map[models.ClientId]net.Conn), + } + res.cond = sync.NewCond(&res.mutex) + + go func() { + for { + conn, err := res.Accept() + log.Printf("testlistener: Got connection %v %v", conn, err) + if err != nil { + return + } + } + }() + + go func() { + select { + case <-res.ctx.Done(): + res.mutex.Lock() + res.cond.Broadcast() + res.mutex.Unlock() + } + }() + return res +} + +func (l *TestAgentListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + clientId := models.ClientId(conn.LocalAddr().String()) + log.Printf("testlistener: Storing connection %v %v", clientId, conn) + l.mutex.Lock() + defer l.mutex.Unlock() + l.connections[clientId] = conn + log.Printf("testlistener: broadcasting %v", clientId) + l.cond.Broadcast() + return conn, err +} + +func (l *TestAgentListener) getConnection(clientId models.ClientId) (net.Conn, error) { + l.mutex.Lock() + defer l.mutex.Unlock() + // We need to check the condition before the first cond.wait as well. Otherwise, a broadcast sent + // at this point in time will not be caught, and if there are no further broadcasts happening, then + // the code will hang her. + for ok := l.connections[clientId] != nil; !ok; ok = l.connections[clientId] != nil { + log.Println("Listener cond wait") + l.cond.Wait() + log.Println("Listener awoken") + select { + case <-l.ctx.Done(): + return nil, errors.New("Listenere terminated because context canceled") + default: + } + } + log.Printf("Returning connection %v %v", clientId, l.connections[clientId]) + return l.connections[clientId], nil +} diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index a50d445..c7ec120 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -8,6 +8,7 @@ import ( "log" "net/http" "os" + "runtime" _ "runtime/pprof" "sync" "time" @@ -80,9 +81,18 @@ func AssertWriteData(s *suite.Suite, data string, writer io.Writer) { } func AssertReadData(s *suite.Suite, data string, reader io.Reader) { - buf := make([]byte, len(data)*2) + buf := make([]byte, len(data)+1024) n, err := reader.Read(buf) s.Nil(err) s.Equal(len(data), n) s.Equal(data, string(buf[:n])) } + +func PrintStackTraces() { + buf := make([]byte, 100000) + runtime.Stack(buf, true) + log.Println("STACKTRACE") + log.Println("") + log.Println(string(buf)) + log.Println("") +}