357 lines
8.9 KiB
Go
357 lines
8.9 KiB
Go
package matchmaker
|
|
|
|
import (
|
|
"context"
|
|
"git.wamblee.org/converge/pkg/comms"
|
|
"git.wamblee.org/converge/pkg/models"
|
|
"git.wamblee.org/converge/pkg/support/iowrappers"
|
|
"git.wamblee.org/converge/pkg/testsupport"
|
|
"github.com/stretchr/testify/suite"
|
|
"go.uber.org/goleak"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type MatchMakerTestSuite struct {
|
|
suite.Suite
|
|
|
|
ctx context.Context
|
|
cancelFunc context.CancelFunc
|
|
pprofServer *http.Server
|
|
|
|
notifier *TestNotifier
|
|
matchMaker *MatchMaker
|
|
}
|
|
|
|
type TestNotifier struct {
|
|
// last reported state
|
|
state *models.State
|
|
}
|
|
|
|
func (notifier *TestNotifier) Publish(state *models.State) {
|
|
notifier.state = state
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) SetupSuite() {
|
|
s.pprofServer = testsupport.StartPprof("")
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) TearDownSuite() {
|
|
testsupport.StopPprof(s.ctx, s.pprofServer)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) SetupTest() {
|
|
ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second)
|
|
s.ctx = ctx
|
|
s.cancelFunc = cancelFunc
|
|
|
|
s.notifier = &TestNotifier{}
|
|
s.matchMaker = NewMatchMaker(s.notifier)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) TearDownTest() {
|
|
s.matchMaker.Close()
|
|
s.cancelFunc()
|
|
goleak.VerifyNone(s.T())
|
|
}
|
|
|
|
func TestMatchMakerTestSuite(t *testing.T) {
|
|
suite.Run(t, &MatchMakerTestSuite{})
|
|
}
|
|
|
|
type TestAgent struct {
|
|
agentSideConn io.ReadWriteCloser
|
|
serverSIdeConn io.ReadWriteCloser
|
|
|
|
agentRegistration comms.AgentRegistration
|
|
commChannel comms.CommChannel
|
|
listener *testsupport.TestAgentListener
|
|
}
|
|
|
|
func NewTestAgent(ctx context.Context) *TestAgent {
|
|
res := TestAgent{}
|
|
a, s := testsupport.CreatePipe(ctx)
|
|
res.agentSideConn = a
|
|
res.serverSIdeConn = s
|
|
return &res
|
|
}
|
|
|
|
func (agent *TestAgent) Disconnect() {
|
|
agent.agentSideConn.Close()
|
|
}
|
|
|
|
func (agent *TestAgent) Initialize(s *MatchMakerTestSuite) (comms.ServerInfo, error) {
|
|
return comms.AgentInitialization(agent.agentSideConn, comms.NewEnvironmentInfo("bash"))
|
|
}
|
|
|
|
func (agent *TestAgent) Register(s *MatchMakerTestSuite) error {
|
|
agentRegistration, err := comms.ReceiveRegistrationMessage(agent.agentSideConn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
agent.agentRegistration = agentRegistration
|
|
commChannel, err := comms.NewCommChannel(comms.Agent, agent.agentSideConn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.NotNil(commChannel)
|
|
agent.commChannel = commChannel
|
|
|
|
baseListener := comms.NewAgentListener(commChannel.Session)
|
|
agent.listener = testsupport.NewTestListener(s.ctx, baseListener)
|
|
return nil
|
|
}
|
|
|
|
type TestClient struct {
|
|
clientSideConn io.ReadWriteCloser
|
|
serverSIdeConn iowrappers.ReadWriteAddrCloser
|
|
publicId models.RendezVousId
|
|
clientId models.ClientId
|
|
|
|
// for wsproxy mode
|
|
serverProtocol comms.ProtocolVersion
|
|
clientConnectionInfo comms.ClientConnectionInfo
|
|
}
|
|
|
|
func NewTestClient(ctx context.Context) *TestClient {
|
|
a, b := testsupport.CreatePipe(ctx)
|
|
res := TestClient{
|
|
clientSideConn: a,
|
|
serverSIdeConn: iowrappers.NewSimpleReadWriteAddrCloser(b,
|
|
testsupport.DummyRemoteAddr("remoteaddr")),
|
|
}
|
|
return &res
|
|
}
|
|
|
|
func (c *TestClient) Disconnect() {
|
|
c.clientSideConn.Close()
|
|
}
|
|
|
|
func (c *TestClient) WsproxyInit() error {
|
|
channel := comms.NewGOBChannel(c.clientSideConn)
|
|
serverProtocol, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.serverProtocol = serverProtocol
|
|
|
|
clientConnectionInfo, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.clientConnectionInfo = clientConnectionInfo
|
|
err = comms.SendWithTimeout(channel, comms.NewEnvironmentInfo(os.Getenv("SHELL")))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_newMatchMaker() {
|
|
s.checkState(0, 0)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_singleAgent() {
|
|
publicId := models.RendezVousId("abc")
|
|
agent := NewTestAgent(s.ctx)
|
|
waitForAgentFunc := s.registerAgent(publicId, agent)
|
|
|
|
s.checkState(1, 0)
|
|
|
|
// required for connection loss detection
|
|
go waitForAgentFunc()
|
|
|
|
agent.Disconnect()
|
|
s.checkState(0, 0)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_singleAgentAndClient() {
|
|
publicId := models.RendezVousId("abc")
|
|
agent := NewTestAgent(s.ctx)
|
|
|
|
waitForAgentFunc := s.registerAgent(publicId, agent)
|
|
go waitForAgentFunc()
|
|
|
|
client, err := s.connectClient(publicId, false)
|
|
s.Nil(err)
|
|
|
|
s.checkState(1, 1)
|
|
|
|
agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
|
|
log.Printf("Agent side conn %v", agentClientSideConn)
|
|
s.Nil(err)
|
|
testsupport.BidirectionalConnectionCheck(
|
|
&s.Suite, "testmsg",
|
|
client.clientSideConn,
|
|
agentClientSideConn)
|
|
|
|
client.Disconnect()
|
|
// It is the agents choice to exit> The test agent does not exit by default when
|
|
// there are no more connections.
|
|
s.checkState(1, 0)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_ConnectCLientToUnknownAgent() {
|
|
publicId := models.RendezVousId("abc")
|
|
|
|
_, err := s.connectClient(publicId, false)
|
|
s.NotNil(err)
|
|
s.checkState(0, 0)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_multipleAgentsAndClients() {
|
|
agents := []string{"abc", "def", "ghi"}
|
|
clients := map[string]int{"abc": 3, "def": 2, "ghi": 5}
|
|
log.Printf("agents %v, clients %v", agents, clients)
|
|
testAgents := make(map[string]*TestAgent)
|
|
for _, publicId := range agents {
|
|
agent := NewTestAgent(s.ctx)
|
|
testAgents[publicId] = agent
|
|
waitForAgentFunc := s.registerAgent(models.RendezVousId(publicId), agent)
|
|
s.Require().NotNil(waitForAgentFunc)
|
|
go waitForAgentFunc()
|
|
}
|
|
testClients := make([]*TestClient, 0)
|
|
for publicId, nclients := range clients {
|
|
for range nclients {
|
|
client, err := s.connectClient(models.RendezVousId(publicId), false)
|
|
s.Require().Nil(err)
|
|
testClients = append(testClients, client)
|
|
}
|
|
}
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
// bidirectional connection test
|
|
for _, testClient := range testClients {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
agent := testAgents[string(testClient.publicId)]
|
|
agentClientSideConn, err := agent.listener.GetConnection(string(testClient.clientId))
|
|
s.Nil(err)
|
|
log.Printf("Testing bi-directional commication client %v agent %v", testClient.clientId, testClient.publicId)
|
|
testsupport.BidirectionalConnectionCheck(
|
|
&s.Suite, "testmsg"+string(testClient.clientId),
|
|
testClient.clientSideConn,
|
|
agentClientSideConn)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) connectClient(publicId models.RendezVousId, wsproxyMode bool) (*TestClient, error) {
|
|
client := NewTestClient(s.ctx)
|
|
var clientId models.ClientId
|
|
res := testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
//server
|
|
clientIdCreated, synchronizer, err := s.matchMaker.Connect(wsproxyMode, publicId, client.serverSIdeConn)
|
|
clientId = clientIdCreated
|
|
if err == nil {
|
|
log.Println("test: synchronizing streams.")
|
|
go synchronizer()
|
|
}
|
|
return err
|
|
},
|
|
func() any {
|
|
if wsproxyMode {
|
|
client.WsproxyInit()
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if res[0] != nil {
|
|
return client, res[0].(error)
|
|
}
|
|
client.publicId = publicId
|
|
client.clientId = clientId
|
|
return client, nil
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) {
|
|
s.True(testsupport.CheckCondition(s.ctx, func() bool {
|
|
return nAgents == len(s.notifier.state.Agents)
|
|
}))
|
|
s.True(testsupport.CheckCondition(s.ctx, func() bool {
|
|
return nClients == len(s.notifier.state.Clients)
|
|
}))
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) registerAgent(publicId models.RendezVousId, agent *TestAgent) WaitForAgentFunc {
|
|
res := testsupport.RunAndWait(
|
|
&s.Suite,
|
|
func() any {
|
|
// ignore waitFunc for now.
|
|
waitFunc, err := s.matchMaker.Register(publicId, agent.serverSIdeConn)
|
|
s.Nil(err)
|
|
log.Printf("MatchMaskerTest: Agent registered by server")
|
|
return waitFunc
|
|
},
|
|
func() any {
|
|
_, err := agent.Initialize(s)
|
|
if err != nil {
|
|
s.Nil(err)
|
|
return nil
|
|
}
|
|
err = agent.Register(s)
|
|
if err != nil {
|
|
s.Nil(err)
|
|
return nil
|
|
}
|
|
log.Println("MatchMakerTest: Agent connected to server")
|
|
return nil
|
|
})
|
|
|
|
if res[0] == nil {
|
|
return nil
|
|
}
|
|
return res[0].(WaitForAgentFunc)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_ConnectWsproxyMode() {
|
|
publicId := models.RendezVousId("abc")
|
|
agent := NewTestAgent(s.ctx)
|
|
|
|
waitForAgentFunc := s.registerAgent(publicId, agent)
|
|
go waitForAgentFunc()
|
|
|
|
client, err := s.connectClient(publicId, true)
|
|
s.Nil(err)
|
|
|
|
s.checkState(1, 1)
|
|
|
|
agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
|
|
log.Printf("Agent side conn %v", agentClientSideConn)
|
|
s.Nil(err)
|
|
testsupport.BidirectionalConnectionCheck(
|
|
&s.Suite, "testmsg",
|
|
client.clientSideConn,
|
|
agentClientSideConn)
|
|
|
|
s.True(client.clientConnectionInfo.Ok)
|
|
|
|
client.Disconnect()
|
|
// It is the agents choice to exit> The test agent does not exit by default when
|
|
// there are no more connections.
|
|
s.checkState(1, 0)
|
|
}
|
|
|
|
func (s *MatchMakerTestSuite) Test_ConnectWsproxyModeAgentNotFound() {
|
|
publicId := models.RendezVousId("abc")
|
|
|
|
client, err := s.connectClient(publicId, true)
|
|
s.NotNil(err)
|
|
s.True(strings.Contains(err.Error(), "No agent found for rendez-vous id"))
|
|
s.False(client.clientConnectionInfo.Ok)
|
|
|
|
s.checkState(0, 0)
|
|
}
|