357 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			357 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package matchmaker
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"git.wamblee.org/converge/pkg/comms"
 | 
						|
	"git.wamblee.org/converge/pkg/models"
 | 
						|
	"git.wamblee.org/converge/pkg/support/iowrappers"
 | 
						|
	"git.wamblee.org/converge/pkg/testsupport"
 | 
						|
	"github.com/stretchr/testify/suite"
 | 
						|
	"go.uber.org/goleak"
 | 
						|
	"io"
 | 
						|
	"log"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
type MatchMakerTestSuite struct {
 | 
						|
	suite.Suite
 | 
						|
 | 
						|
	ctx         context.Context
 | 
						|
	cancelFunc  context.CancelFunc
 | 
						|
	pprofServer *http.Server
 | 
						|
 | 
						|
	notifier   *TestNotifier
 | 
						|
	matchMaker *MatchMaker
 | 
						|
}
 | 
						|
 | 
						|
type TestNotifier struct {
 | 
						|
	// last reported state
 | 
						|
	state *models.State
 | 
						|
}
 | 
						|
 | 
						|
func (notifier *TestNotifier) Publish(state *models.State) {
 | 
						|
	notifier.state = state
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) SetupSuite() {
 | 
						|
	s.pprofServer = testsupport.StartPprof("")
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) TearDownSuite() {
 | 
						|
	testsupport.StopPprof(s.ctx, s.pprofServer)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) SetupTest() {
 | 
						|
	ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second)
 | 
						|
	s.ctx = ctx
 | 
						|
	s.cancelFunc = cancelFunc
 | 
						|
 | 
						|
	s.notifier = &TestNotifier{}
 | 
						|
	s.matchMaker = NewMatchMaker(s.notifier)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) TearDownTest() {
 | 
						|
	s.matchMaker.Close()
 | 
						|
	s.cancelFunc()
 | 
						|
	goleak.VerifyNone(s.T())
 | 
						|
}
 | 
						|
 | 
						|
func TestMatchMakerTestSuite(t *testing.T) {
 | 
						|
	suite.Run(t, &MatchMakerTestSuite{})
 | 
						|
}
 | 
						|
 | 
						|
type TestAgent struct {
 | 
						|
	agentSideConn  io.ReadWriteCloser
 | 
						|
	serverSIdeConn io.ReadWriteCloser
 | 
						|
 | 
						|
	agentRegistration comms.AgentRegistration
 | 
						|
	commChannel       comms.CommChannel
 | 
						|
	listener          *testsupport.TestAgentListener
 | 
						|
}
 | 
						|
 | 
						|
func NewTestAgent(ctx context.Context) *TestAgent {
 | 
						|
	res := TestAgent{}
 | 
						|
	a, s := testsupport.CreatePipe(ctx)
 | 
						|
	res.agentSideConn = a
 | 
						|
	res.serverSIdeConn = s
 | 
						|
	return &res
 | 
						|
}
 | 
						|
 | 
						|
func (agent *TestAgent) Disconnect() {
 | 
						|
	agent.agentSideConn.Close()
 | 
						|
}
 | 
						|
 | 
						|
func (agent *TestAgent) Initialize(s *MatchMakerTestSuite) (comms.ServerInfo, error) {
 | 
						|
	return comms.AgentInitialization(agent.agentSideConn, comms.NewEnvironmentInfo("bash"))
 | 
						|
}
 | 
						|
 | 
						|
func (agent *TestAgent) Register(s *MatchMakerTestSuite) error {
 | 
						|
	agentRegistration, err := comms.ReceiveRegistrationMessage(agent.agentSideConn)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	agent.agentRegistration = agentRegistration
 | 
						|
	commChannel, err := comms.NewCommChannel(comms.Agent, agent.agentSideConn)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	s.NotNil(commChannel)
 | 
						|
	agent.commChannel = commChannel
 | 
						|
 | 
						|
	baseListener := comms.NewAgentListener(commChannel.Session)
 | 
						|
	agent.listener = testsupport.NewTestListener(s.ctx, baseListener)
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
type TestClient struct {
 | 
						|
	clientSideConn io.ReadWriteCloser
 | 
						|
	serverSIdeConn iowrappers.ReadWriteAddrCloser
 | 
						|
	publicId       models.RendezVousId
 | 
						|
	clientId       models.ClientId
 | 
						|
 | 
						|
	// for wsproxy mode
 | 
						|
	serverProtocol       comms.ProtocolVersion
 | 
						|
	clientConnectionInfo comms.ClientConnectionInfo
 | 
						|
}
 | 
						|
 | 
						|
func NewTestClient(ctx context.Context) *TestClient {
 | 
						|
	a, b := testsupport.CreatePipe(ctx)
 | 
						|
	res := TestClient{
 | 
						|
		clientSideConn: a,
 | 
						|
		serverSIdeConn: iowrappers.NewSimpleReadWriteAddrCloser(b,
 | 
						|
			testsupport.DummyRemoteAddr("remoteaddr")),
 | 
						|
	}
 | 
						|
	return &res
 | 
						|
}
 | 
						|
 | 
						|
func (c *TestClient) Disconnect() {
 | 
						|
	c.clientSideConn.Close()
 | 
						|
}
 | 
						|
 | 
						|
func (c *TestClient) WsproxyInit() error {
 | 
						|
	channel := comms.NewGOBChannel(c.clientSideConn)
 | 
						|
	serverProtocol, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](channel)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	c.serverProtocol = serverProtocol
 | 
						|
 | 
						|
	clientConnectionInfo, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](channel)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	c.clientConnectionInfo = clientConnectionInfo
 | 
						|
	err = comms.SendWithTimeout(channel, comms.NewEnvironmentInfo(os.Getenv("SHELL")))
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_newMatchMaker() {
 | 
						|
	s.checkState(0, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_singleAgent() {
 | 
						|
	publicId := models.RendezVousId("abc")
 | 
						|
	agent := NewTestAgent(s.ctx)
 | 
						|
	waitForAgentFunc := s.registerAgent(publicId, agent)
 | 
						|
 | 
						|
	s.checkState(1, 0)
 | 
						|
 | 
						|
	// required for connection loss detection
 | 
						|
	go waitForAgentFunc()
 | 
						|
 | 
						|
	agent.Disconnect()
 | 
						|
	s.checkState(0, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_singleAgentAndClient() {
 | 
						|
	publicId := models.RendezVousId("abc")
 | 
						|
	agent := NewTestAgent(s.ctx)
 | 
						|
 | 
						|
	waitForAgentFunc := s.registerAgent(publicId, agent)
 | 
						|
	go waitForAgentFunc()
 | 
						|
 | 
						|
	client, err := s.connectClient(publicId, false)
 | 
						|
	s.Nil(err)
 | 
						|
 | 
						|
	s.checkState(1, 1)
 | 
						|
 | 
						|
	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
 | 
						|
	log.Printf("Agent side conn %v", agentClientSideConn)
 | 
						|
	s.Nil(err)
 | 
						|
	testsupport.BidirectionalConnectionCheck(
 | 
						|
		&s.Suite, "testmsg",
 | 
						|
		client.clientSideConn,
 | 
						|
		agentClientSideConn)
 | 
						|
 | 
						|
	client.Disconnect()
 | 
						|
	// It is the agents choice to exit> The test agent does not exit by default when
 | 
						|
	// there are no more connections.
 | 
						|
	s.checkState(1, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_ConnectCLientToUnknownAgent() {
 | 
						|
	publicId := models.RendezVousId("abc")
 | 
						|
 | 
						|
	_, err := s.connectClient(publicId, false)
 | 
						|
	s.NotNil(err)
 | 
						|
	s.checkState(0, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_multipleAgentsAndClients() {
 | 
						|
	agents := []string{"abc", "def", "ghi"}
 | 
						|
	clients := map[string]int{"abc": 3, "def": 2, "ghi": 5}
 | 
						|
	log.Printf("agents %v, clients %v", agents, clients)
 | 
						|
	testAgents := make(map[string]*TestAgent)
 | 
						|
	for _, publicId := range agents {
 | 
						|
		agent := NewTestAgent(s.ctx)
 | 
						|
		testAgents[publicId] = agent
 | 
						|
		waitForAgentFunc := s.registerAgent(models.RendezVousId(publicId), agent)
 | 
						|
		s.Require().NotNil(waitForAgentFunc)
 | 
						|
		go waitForAgentFunc()
 | 
						|
	}
 | 
						|
	testClients := make([]*TestClient, 0)
 | 
						|
	for publicId, nclients := range clients {
 | 
						|
		for range nclients {
 | 
						|
			client, err := s.connectClient(models.RendezVousId(publicId), false)
 | 
						|
			s.Require().Nil(err)
 | 
						|
			testClients = append(testClients, client)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	wg := sync.WaitGroup{}
 | 
						|
 | 
						|
	// bidirectional connection test
 | 
						|
	for _, testClient := range testClients {
 | 
						|
		wg.Add(1)
 | 
						|
		go func() {
 | 
						|
			defer wg.Done()
 | 
						|
			agent := testAgents[string(testClient.publicId)]
 | 
						|
			agentClientSideConn, err := agent.listener.GetConnection(string(testClient.clientId))
 | 
						|
			s.Nil(err)
 | 
						|
			log.Printf("Testing bi-directional commication client %v agent %v", testClient.clientId, testClient.publicId)
 | 
						|
			testsupport.BidirectionalConnectionCheck(
 | 
						|
				&s.Suite, "testmsg"+string(testClient.clientId),
 | 
						|
				testClient.clientSideConn,
 | 
						|
				agentClientSideConn)
 | 
						|
		}()
 | 
						|
	}
 | 
						|
	wg.Wait()
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) connectClient(publicId models.RendezVousId, wsproxyMode bool) (*TestClient, error) {
 | 
						|
	client := NewTestClient(s.ctx)
 | 
						|
	var clientId models.ClientId
 | 
						|
	res := testsupport.RunAndWait(
 | 
						|
		&s.Suite,
 | 
						|
		func() any {
 | 
						|
			//server
 | 
						|
			clientIdCreated, synchronizer, err := s.matchMaker.Connect(wsproxyMode, publicId, client.serverSIdeConn)
 | 
						|
			clientId = clientIdCreated
 | 
						|
			if err == nil {
 | 
						|
				log.Println("test: synchronizing streams.")
 | 
						|
				go synchronizer()
 | 
						|
			}
 | 
						|
			return err
 | 
						|
		},
 | 
						|
		func() any {
 | 
						|
			if wsproxyMode {
 | 
						|
				client.WsproxyInit()
 | 
						|
			}
 | 
						|
			return nil
 | 
						|
		})
 | 
						|
 | 
						|
	if res[0] != nil {
 | 
						|
		return client, res[0].(error)
 | 
						|
	}
 | 
						|
	client.publicId = publicId
 | 
						|
	client.clientId = clientId
 | 
						|
	return client, nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) {
 | 
						|
	s.True(testsupport.CheckCondition(s.ctx, func() bool {
 | 
						|
		return nAgents == len(s.notifier.state.Agents)
 | 
						|
	}))
 | 
						|
	s.True(testsupport.CheckCondition(s.ctx, func() bool {
 | 
						|
		return nClients == len(s.notifier.state.Clients)
 | 
						|
	}))
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) registerAgent(publicId models.RendezVousId, agent *TestAgent) WaitForAgentFunc {
 | 
						|
	res := testsupport.RunAndWait(
 | 
						|
		&s.Suite,
 | 
						|
		func() any {
 | 
						|
			// ignore waitFunc for now.
 | 
						|
			waitFunc, err := s.matchMaker.Register(publicId, agent.serverSIdeConn)
 | 
						|
			s.Nil(err)
 | 
						|
			log.Printf("MatchMaskerTest: Agent registered by server")
 | 
						|
			return waitFunc
 | 
						|
		},
 | 
						|
		func() any {
 | 
						|
			_, err := agent.Initialize(s)
 | 
						|
			if err != nil {
 | 
						|
				s.Nil(err)
 | 
						|
				return nil
 | 
						|
			}
 | 
						|
			err = agent.Register(s)
 | 
						|
			if err != nil {
 | 
						|
				s.Nil(err)
 | 
						|
				return nil
 | 
						|
			}
 | 
						|
			log.Println("MatchMakerTest: Agent connected to server")
 | 
						|
			return nil
 | 
						|
		})
 | 
						|
 | 
						|
	if res[0] == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	return res[0].(WaitForAgentFunc)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_ConnectWsproxyMode() {
 | 
						|
	publicId := models.RendezVousId("abc")
 | 
						|
	agent := NewTestAgent(s.ctx)
 | 
						|
 | 
						|
	waitForAgentFunc := s.registerAgent(publicId, agent)
 | 
						|
	go waitForAgentFunc()
 | 
						|
 | 
						|
	client, err := s.connectClient(publicId, true)
 | 
						|
	s.Nil(err)
 | 
						|
 | 
						|
	s.checkState(1, 1)
 | 
						|
 | 
						|
	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
 | 
						|
	log.Printf("Agent side conn %v", agentClientSideConn)
 | 
						|
	s.Nil(err)
 | 
						|
	testsupport.BidirectionalConnectionCheck(
 | 
						|
		&s.Suite, "testmsg",
 | 
						|
		client.clientSideConn,
 | 
						|
		agentClientSideConn)
 | 
						|
 | 
						|
	s.True(client.clientConnectionInfo.Ok)
 | 
						|
 | 
						|
	client.Disconnect()
 | 
						|
	// It is the agents choice to exit> The test agent does not exit by default when
 | 
						|
	// there are no more connections.
 | 
						|
	s.checkState(1, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (s *MatchMakerTestSuite) Test_ConnectWsproxyModeAgentNotFound() {
 | 
						|
	publicId := models.RendezVousId("abc")
 | 
						|
 | 
						|
	client, err := s.connectClient(publicId, true)
 | 
						|
	s.NotNil(err)
 | 
						|
	s.True(strings.Contains(err.Error(), "No agent found for rendez-vous id"))
 | 
						|
	s.False(client.clientConnectionInfo.Ok)
 | 
						|
 | 
						|
	s.checkState(0, 0)
 | 
						|
}
 |