// Source: converge/pkg/server/matchmaker/matchmaker_test.go (357 lines, 8.9 KiB, Go)

package matchmaker
import (
"context"
"git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/models"
"git.wamblee.org/converge/pkg/support/ioutils"
"git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite"
"go.uber.org/goleak"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"testing"
"time"
)
// MatchMakerTestSuite is the testify suite for MatchMaker. A fresh matchmaker,
// notifier and context are created for every test in SetupTest.
type MatchMakerTestSuite struct {
	suite.Suite

	// ctx is the per-test context created in SetupTest (10s argument to
	// CreateTestContext, presumably a timeout); cancelled in TearDownTest.
	ctx context.Context
	// cancelFunc cancels ctx.
	cancelFunc context.CancelFunc
	// pprofServer is started in SetupSuite and stopped in TearDownSuite.
	pprofServer *http.Server
	// notifier records the state published by matchMaker.
	notifier *TestNotifier
	// matchMaker is the system under test, recreated for every test.
	matchMaker *MatchMaker
}
// TestNotifier is a test double for the matchmaker's state notifier. It keeps
// only the most recently published state so tests can inspect it.
type TestNotifier struct {
	// state is the last state passed to Publish; nil until the first call.
	state *models.State
}

// Publish records latest as the current state, replacing any earlier one.
//
// NOTE(review): state is written here and read by checkState without any
// synchronization. If the matchmaker publishes from a goroutine other than the
// test's, this is a data race under `go test -race`. Consider guarding it with
// a mutex and reading it through an accessor.
func (notifier *TestNotifier) Publish(latest *models.State) {
	notifier.state = latest
}
// SetupSuite runs once before all tests and starts the pprof server through
// testsupport.StartPprof.
func (s *MatchMakerTestSuite) SetupSuite() {
	server := testsupport.StartPprof("")
	s.pprofServer = server
}
// TearDownSuite runs once after all tests and stops the pprof server started
// in SetupSuite.
func (s *MatchMakerTestSuite) TearDownSuite() {
	// s.ctx belongs to the last test and was already cancelled by TearDownTest.
	// It is nil if no test ran. Give the shutdown its own live deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	testsupport.StopPprof(ctx, s.pprofServer)
}
// SetupTest runs before each test. It creates a per-test context and a new
// matchmaker wired to a new TestNotifier, so tests do not share state.
func (s *MatchMakerTestSuite) SetupTest() {
	s.ctx, s.cancelFunc = testsupport.CreateTestContext(context.Background(), 10*time.Second)
	notifier := &TestNotifier{}
	s.notifier = notifier
	s.matchMaker = NewMatchMaker(notifier)
}
// TearDownTest runs after each test. It closes the matchmaker, cancels the
// per-test context, and then uses goleak to check that no unexpected
// goroutines are left running.
func (s *MatchMakerTestSuite) TearDownTest() {
	t := s.T()
	s.matchMaker.Close()
	s.cancelFunc()
	goleak.VerifyNone(t)
}
// TestMatchMakerTestSuite is the `go test` entry point. It runs every test
// method of MatchMakerTestSuite through testify's suite runner.
func TestMatchMakerTestSuite(t *testing.T) {
	suite.Run(t, new(MatchMakerTestSuite))
}
// TestAgent simulates an agent connected to the matchmaker over an in-memory
// pipe.
type TestAgent struct {
	// agentSideConn is the agent's end of the pipe.
	agentSideConn io.ReadWriteCloser
	// serverSIdeConn is the server's end of the pipe, handed to MatchMaker.Register.
	serverSIdeConn io.ReadWriteCloser
	// agentRegistration is the registration message received from the server.
	agentRegistration comms.AgentRegistration
	// commChannel is the agent-side comm channel created in Register.
	commChannel comms.CommChannel
	// listener yields the connections the server opens towards this agent.
	listener *testsupport.TestAgentListener
}
// NewTestAgent creates a TestAgent whose agent and server ends are the two
// sides of a pipe from testsupport.CreatePipe.
func NewTestAgent(ctx context.Context) *TestAgent {
	agentEnd, serverEnd := testsupport.CreatePipe(ctx)
	return &TestAgent{
		agentSideConn:  agentEnd,
		serverSIdeConn: serverEnd,
	}
}
// Disconnect closes the agent's end of the pipe, simulating connection loss.
func (agent *TestAgent) Disconnect() {
	// The close error is irrelevant for the tests; only the disconnect matters.
	_ = agent.agentSideConn.Close()
}
// Initialize performs the agent side of the initialization handshake,
// announcing a "bash" environment. The suite parameter is unused.
func (agent *TestAgent) Initialize(s *MatchMakerTestSuite) (comms.ServerInfo, error) {
	env := comms.NewEnvironmentInfo("bash")
	return comms.AgentInitialization(agent.agentSideConn, env)
}
// Register completes the agent side of registration. It receives the
// registration message, opens the agent comm channel over the pipe, and wraps
// the channel's session in a test listener that yields incoming client
// connections. Errors from the comms layer are returned unchanged.
func (agent *TestAgent) Register(s *MatchMakerTestSuite) error {
	registration, err := comms.ReceiveRegistrationMessage(agent.agentSideConn)
	if err != nil {
		return err
	}
	agent.agentRegistration = registration

	channel, err := comms.NewCommChannel(comms.Agent, agent.agentSideConn)
	if err != nil {
		return err
	}
	s.NotNil(channel)
	agent.commChannel = channel

	agent.listener = testsupport.NewTestListener(s.ctx, comms.NewAgentListener(channel.Session))
	return nil
}
// TestClient simulates a client connecting to an agent through the matchmaker.
type TestClient struct {
	// clientSideConn is the client's end of the pipe.
	clientSideConn io.ReadWriteCloser
	// serverSIdeConn is the server's end, handed to MatchMaker.Connect.
	serverSIdeConn ioutils.ReadWriteAddrCloser
	// publicId and clientId are filled in by connectClient on success.
	publicId models.RendezVousId
	clientId models.ClientId

	// Filled in by WsproxyInit (wsproxy mode only).
	serverProtocol       comms.ProtocolVersion
	clientConnectionInfo comms.ClientConnectionInfo
}
// NewTestClient creates a TestClient over a new pipe. The server end is given
// a dummy remote address so it can be passed where an addressed connection is
// required.
func NewTestClient(ctx context.Context) *TestClient {
	clientEnd, serverEnd := testsupport.CreatePipe(ctx)
	remoteAddr := testsupport.DummyRemoteAddr("remoteaddr")
	return &TestClient{
		clientSideConn: clientEnd,
		serverSIdeConn: ioutils.NewSimpleReadWriteAddrCloser(serverEnd, remoteAddr),
	}
}
// Disconnect closes the client's end of the pipe.
func (c *TestClient) Disconnect() {
	// The close error is irrelevant for the tests; only the disconnect matters.
	_ = c.clientSideConn.Close()
}
// WsproxyInit performs the client side of the wsproxy handshake over a GOB
// channel. It receives the server's protocol version and then the client
// connection info (both stored on c), and finally sends this process's SHELL
// as environment info. The first error encountered is returned.
func (c *TestClient) WsproxyInit() error {
	channel := comms.NewGOBChannel(c.clientSideConn)

	protocol, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](channel)
	if err != nil {
		return err
	}
	c.serverProtocol = protocol

	info, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](channel)
	if err != nil {
		return err
	}
	c.clientConnectionInfo = info

	return comms.SendWithTimeout(channel, comms.NewEnvironmentInfo(os.Getenv("SHELL")))
}
// Test_newMatchMaker checks that a freshly created matchmaker reports no
// agents and no clients.
func (s *MatchMakerTestSuite) Test_newMatchMaker() {
	const wantAgents, wantClients = 0, 0
	s.checkState(wantAgents, wantClients)
}
// Test_singleAgent registers one agent and checks that it appears in the
// published state. It then checks that the agent disappears after it
// disconnects.
func (s *MatchMakerTestSuite) Test_singleAgent() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)
	waitForAgentFunc := s.registerAgent(publicId, agent)
	// registerAgent may return nil. Calling a nil func in a goroutine would
	// crash the whole test binary instead of failing this test.
	s.Require().NotNil(waitForAgentFunc)
	s.checkState(1, 0)

	// required for connection loss detection
	go waitForAgentFunc()
	agent.Disconnect()
	s.checkState(0, 0)
}
// Test_singleAgentAndClient connects one client to one registered agent and
// checks that data flows both ways. After the client disconnects, only the
// agent should remain in the published state.
func (s *MatchMakerTestSuite) Test_singleAgentAndClient() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)
	waitForAgentFunc := s.registerAgent(publicId, agent)
	// registerAgent may return nil. Calling a nil func in a goroutine would
	// crash the whole test binary instead of failing this test.
	s.Require().NotNil(waitForAgentFunc)
	go waitForAgentFunc()

	client, err := s.connectClient(publicId, false)
	// Without a connection, client.clientId below is meaningless.
	s.Require().NoError(err)
	s.checkState(1, 1)

	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
	log.Printf("Agent side conn %v", agentClientSideConn)
	// Do not run the connection check against a nil connection.
	s.Require().NoError(err)
	testsupport.BidirectionalConnectionCheck(
		&s.Suite, "testmsg",
		client.clientSideConn,
		agentClientSideConn)

	client.Disconnect()
	// It is the agent's choice to exit. The test agent does not exit by default
	// when there are no more connections.
	s.checkState(1, 0)
}
// Test_ConnectCLientToUnknownAgent checks that connecting to a rendez-vous id
// with no registered agent fails and leaves the state empty.
func (s *MatchMakerTestSuite) Test_ConnectCLientToUnknownAgent() {
	// No agent is registered in this test, so any id is unknown.
	_, err := s.connectClient(models.RendezVousId("abc"), false)
	s.Error(err)
	s.checkState(0, 0)
}
// Test_multipleAgentsAndClients registers several agents and connects a
// different number of clients to each. It then checks bidirectional
// communication for all clients concurrently.
func (s *MatchMakerTestSuite) Test_multipleAgentsAndClients() {
	agents := []string{"abc", "def", "ghi"}
	clients := map[string]int{"abc": 3, "def": 2, "ghi": 5}
	log.Printf("agents %v, clients %v", agents, clients)

	testAgents := make(map[string]*TestAgent, len(agents))
	for _, publicId := range agents {
		agent := NewTestAgent(s.ctx)
		testAgents[publicId] = agent
		waitForAgentFunc := s.registerAgent(models.RendezVousId(publicId), agent)
		s.Require().NotNil(waitForAgentFunc)
		go waitForAgentFunc()
	}

	var testClients []*TestClient
	for publicId, nclients := range clients {
		for range nclients {
			client, err := s.connectClient(models.RendezVousId(publicId), false)
			s.Require().NoError(err)
			testClients = append(testClients, client)
		}
	}

	// Bidirectional connection test, one goroutine per client.
	var wg sync.WaitGroup
	for _, testClient := range testClients {
		wg.Add(1)
		go func() {
			defer wg.Done()
			agent := testAgents[string(testClient.publicId)]
			agentClientSideConn, err := agent.listener.GetConnection(string(testClient.clientId))
			// Require (FailNow) must not be called off the test goroutine. Record
			// the failure and stop, instead of checking a nil connection.
			if !s.NoError(err) {
				return
			}
			log.Printf("Testing bi-directional communication client %v agent %v", testClient.clientId, testClient.publicId)
			testsupport.BidirectionalConnectionCheck(
				&s.Suite, "testmsg"+string(testClient.clientId),
				testClient.clientSideConn,
				agentClientSideConn)
		}()
	}
	wg.Wait()
}
// connectClient connects a new TestClient to the agent registered under
// publicId. The server side (MatchMaker.Connect) and the client side run
// together through testsupport.RunAndWait. In wsproxy mode the client side
// also performs the wsproxy handshake (WsproxyInit).
//
// The client is returned even on error, so callers can inspect what it
// received (e.g. clientConnectionInfo). publicId and clientId are set on the
// client only on success. A server-side error takes precedence over a
// client-side one.
func (s *MatchMakerTestSuite) connectClient(publicId models.RendezVousId, wsproxyMode bool) (*TestClient, error) {
	client := NewTestClient(s.ctx)
	var clientId models.ClientId
	res := testsupport.RunAndWait(
		&s.Suite,
		func() any {
			// server
			clientIdCreated, synchronizer, err := s.matchMaker.Connect(wsproxyMode, publicId, client.serverSIdeConn)
			clientId = clientIdCreated
			if err != nil {
				return err
			}
			log.Println("test: synchronizing streams.")
			go synchronizer()
			return nil
		},
		func() any {
			// client
			if !wsproxyMode {
				return nil
			}
			// Previously ignored: a failed handshake must surface to the caller.
			if err := client.WsproxyInit(); err != nil {
				return err
			}
			return nil
		})
	if err, ok := res[0].(error); ok && err != nil {
		return client, err
	}
	if err, ok := res[1].(error); ok && err != nil {
		return client, err
	}
	client.publicId = publicId
	client.clientId = clientId
	return client, nil
}
// checkState waits, through testsupport.CheckCondition on s.ctx, for the last
// published matchmaker state to report exactly nAgents agents and nClients
// clients. It records a test failure if that condition is never met.
func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) {
	// Both counts are evaluated against one snapshot, so a pass describes a
	// single consistent state. A nil state (nothing published yet) just does
	// not match, instead of panicking inside the condition.
	ok := testsupport.CheckCondition(s.ctx, func() bool {
		state := s.notifier.state
		return state != nil &&
			len(state.Agents) == nAgents &&
			len(state.Clients) == nClients
	})
	s.True(ok, "expected %d agents and %d clients, last published state: %+v",
		nAgents, nClients, s.notifier.state)
}
// registerAgent registers agent under publicId. The server side
// (MatchMaker.Register) and the agent side (Initialize + Register) run
// together through testsupport.RunAndWait. It returns the server's
// WaitForAgentFunc, or nil if registration failed; failures are also recorded
// on the suite.
func (s *MatchMakerTestSuite) registerAgent(publicId models.RendezVousId, agent *TestAgent) WaitForAgentFunc {
	res := testsupport.RunAndWait(
		&s.Suite,
		func() any {
			// server
			waitFunc, err := s.matchMaker.Register(publicId, agent.serverSIdeConn)
			if !s.NoError(err) {
				return nil
			}
			log.Printf("MatchMakerTest: Agent registered by server")
			return waitFunc
		},
		func() any {
			// agent
			if _, err := agent.Initialize(s); !s.NoError(err) {
				return nil
			}
			if err := agent.Register(s); !s.NoError(err) {
				return nil
			}
			log.Println("MatchMakerTest: Agent connected to server")
			return nil
		})
	// A nil func stored in an `any` is not == nil, so a plain nil check on
	// res[0] misses it. The comma-ok assertion yields nil in both cases.
	waitFunc, _ := res[0].(WaitForAgentFunc)
	return waitFunc
}
// Test_ConnectWsproxyMode connects a client in wsproxy mode to a registered
// agent. It checks that data flows both ways and that the connection info the
// client received reports Ok. After the client disconnects, only the agent
// should remain in the published state.
func (s *MatchMakerTestSuite) Test_ConnectWsproxyMode() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)
	waitForAgentFunc := s.registerAgent(publicId, agent)
	// registerAgent may return nil. Calling a nil func in a goroutine would
	// crash the whole test binary instead of failing this test.
	s.Require().NotNil(waitForAgentFunc)
	go waitForAgentFunc()

	client, err := s.connectClient(publicId, true)
	// Without a connection, the client fields used below are meaningless.
	s.Require().NoError(err)
	s.checkState(1, 1)

	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
	log.Printf("Agent side conn %v", agentClientSideConn)
	// Do not run the connection check against a nil connection.
	s.Require().NoError(err)
	testsupport.BidirectionalConnectionCheck(
		&s.Suite, "testmsg",
		client.clientSideConn,
		agentClientSideConn)
	s.True(client.clientConnectionInfo.Ok)

	client.Disconnect()
	// It is the agent's choice to exit. The test agent does not exit by default
	// when there are no more connections.
	s.checkState(1, 0)
}
// Test_ConnectWsproxyModeAgentNotFound checks that a wsproxy-mode connection
// to an unknown rendez-vous id fails with a "No agent found" error. It also
// checks that the client was told the connection is not Ok and that the
// state stays empty.
func (s *MatchMakerTestSuite) Test_ConnectWsproxyModeAgentNotFound() {
	publicId := models.RendezVousId("abc")
	client, err := s.connectClient(publicId, true)
	// err.Error() below would panic on a nil error, so stop here in that case.
	s.Require().Error(err)
	s.True(strings.Contains(err.Error(), "No agent found for rendez-vous id"),
		"unexpected error: %v", err)
	s.False(client.clientConnectionInfo.Ok)
	s.checkState(0, 0)
}