package matchmaker

import (
	"context"
	"git.wamblee.org/converge/pkg/comms"
	"git.wamblee.org/converge/pkg/models"
	"git.wamblee.org/converge/pkg/support/ioutils"
	"git.wamblee.org/converge/pkg/testsupport"
	"github.com/stretchr/testify/suite"
	"go.uber.org/goleak"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"testing"
	"time"
)

// MatchMakerTestSuite exercises MatchMaker: agent registration, client
// connections, and the state reported through the notifier.
type MatchMakerTestSuite struct {
	suite.Suite

	// Per-test fixtures, recreated in SetupTest and closed in TearDownTest.
	matchMaker *MatchMaker
	notifier   *TestNotifier

	// Per-test context, created in SetupTest and cancelled in TearDownTest.
	ctx        context.Context
	cancelFunc context.CancelFunc

	// Profiling server shared by the whole suite (SetupSuite/TearDownSuite).
	pprofServer *http.Server
}

// TestNotifier is a notifier stub handed to MatchMaker. It only remembers the
// most recently published state so tests can inspect it.
//
// NOTE(review): state is written by Publish and read by checkState without
// any synchronization; if MatchMaker publishes from its own goroutines this
// may be reported by the race detector — confirm.
type TestNotifier struct {
	// state is the last value passed to Publish; nil until the first call.
	state *models.State
}

// Publish records state as the latest reported state, replacing the previous one.
func (n *TestNotifier) Publish(state *models.State) {
	n.state = state
}

// SetupSuite starts the profiling server that lives for the whole suite.
func (s *MatchMakerTestSuite) SetupSuite() {
	server := testsupport.StartPprof("")
	s.pprofServer = server
}

// TearDownSuite stops the profiling server started in SetupSuite.
//
// A fresh, bounded context is used instead of s.ctx. s.ctx belongs to the
// last test and has already been cancelled by TearDownTest, or is nil when no
// test ran, so anything StopPprof does with it would see an expired (or
// missing) context.
func (s *MatchMakerTestSuite) TearDownSuite() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	testsupport.StopPprof(ctx, s.pprofServer)
}

// SetupTest gives every test its own 10-second context and a fresh
// MatchMaker wired to a fresh TestNotifier.
func (s *MatchMakerTestSuite) SetupTest() {
	s.ctx, s.cancelFunc = testsupport.CreateTestContext(context.Background(), 10*time.Second)

	s.notifier = new(TestNotifier)
	s.matchMaker = NewMatchMaker(s.notifier)
}

// TearDownTest releases the per-test fixtures created by SetupTest and then
// uses goleak to fail the test if any goroutines outlived them.
func (s *MatchMakerTestSuite) TearDownTest() {
	// Close the matchmaker, then cancel the per-test context (which the test
	// pipes were created with).
	s.matchMaker.Close()
	s.cancelFunc()

	// The leak check must come last, after everything above has been stopped.
	t := s.T()
	goleak.VerifyNone(t)
}

// TestMatchMakerTestSuite is the `go test` entry point for MatchMakerTestSuite.
func TestMatchMakerTestSuite(t *testing.T) {
	suite.Run(t, new(MatchMakerTestSuite))
}

// TestAgent simulates an agent that talks to the MatchMaker over a pipe
// created by testsupport.CreatePipe.
type TestAgent struct {
	// The two ends of the pipe. The agent speaks on agentSideConn and the
	// MatchMaker is given serverSIdeConn.
	agentSideConn  io.ReadWriteCloser
	serverSIdeConn io.ReadWriteCloser

	// Populated by Register once the server has accepted the agent.
	agentRegistration comms.AgentRegistration
	commChannel       comms.CommChannel
	listener          *testsupport.TestAgentListener
}

// NewTestAgent creates a TestAgent whose two pipe ends are bound to ctx.
func NewTestAgent(ctx context.Context) *TestAgent {
	agentSide, serverSide := testsupport.CreatePipe(ctx)
	return &TestAgent{
		agentSideConn:  agentSide,
		serverSIdeConn: serverSide,
	}
}

// Disconnect closes the agent's end of the pipe, simulating a lost agent.
func (agent *TestAgent) Disconnect() {
	// Best effort: a close error is irrelevant to the tests.
	_ = agent.agentSideConn.Close()
}

// Initialize runs the agent side of the initialization handshake, reporting
// "bash" as the agent's shell environment. s is currently unused.
func (agent *TestAgent) Initialize(s *MatchMakerTestSuite) (comms.ServerInfo, error) {
	env := comms.NewEnvironmentInfo("bash")
	return comms.AgentInitialization(agent.agentSideConn, env)
}

// Register completes the agent side of registration. It receives the
// registration message, opens the agent comm channel on the pipe, and starts
// a test listener (bound to the suite context) for incoming client sessions.
func (agent *TestAgent) Register(s *MatchMakerTestSuite) error {
	registration, err := comms.ReceiveRegistrationMessage(agent.agentSideConn)
	if err != nil {
		return err
	}
	agent.agentRegistration = registration

	channel, err := comms.NewCommChannel(comms.Agent, agent.agentSideConn)
	if err != nil {
		return err
	}
	s.NotNil(channel)
	agent.commChannel = channel

	agent.listener = testsupport.NewTestListener(s.ctx, comms.NewAgentListener(channel.Session))
	return nil
}

// TestClient simulates a client. The test uses clientSideConn as the client,
// and serverSIdeConn is what gets handed to MatchMaker.Connect.
type TestClient struct {
	clientSideConn io.ReadWriteCloser
	serverSIdeConn ioutils.ReadWriteAddrCloser

	// Identity, filled in by connectClient once the server accepted the client.
	publicId models.RendezVousId
	clientId models.ClientId

	// Handshake results, only populated in wsproxy mode (see WsproxyInit).
	serverProtocol       comms.ProtocolVersion
	clientConnectionInfo comms.ClientConnectionInfo
}

// NewTestClient creates a TestClient over a pipe bound to ctx. The server end
// of the pipe reports the dummy remote address "remoteaddr".
func NewTestClient(ctx context.Context) *TestClient {
	clientSide, serverSide := testsupport.CreatePipe(ctx)
	remote := testsupport.DummyRemoteAddr("remoteaddr")
	return &TestClient{
		clientSideConn: clientSide,
		serverSIdeConn: ioutils.NewSimpleReadWriteAddrCloser(serverSide, remote),
	}
}

// Disconnect closes the client's end of the pipe, simulating a client leaving.
func (c *TestClient) Disconnect() {
	// Best effort: a close error is irrelevant to the tests.
	_ = c.clientSideConn.Close()
}

// WsproxyInit runs the client side of the wsproxy handshake over a GOB
// channel:
//  1. receive the server protocol version,
//  2. receive the client connection info,
//  3. send this process's environment info (based on $SHELL).
//
// Values received before a failure are kept on c.
func (c *TestClient) WsproxyInit() error {
	ch := comms.NewGOBChannel(c.clientSideConn)

	protocol, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](ch)
	if err != nil {
		return err
	}
	c.serverProtocol = protocol

	info, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](ch)
	if err != nil {
		return err
	}
	c.clientConnectionInfo = info

	return comms.SendWithTimeout(ch, comms.NewEnvironmentInfo(os.Getenv("SHELL")))
}

// Test_newMatchMaker checks that a fresh MatchMaker reports no agents and no clients.
func (s *MatchMakerTestSuite) Test_newMatchMaker() {
	const wantAgents, wantClients = 0, 0
	s.checkState(wantAgents, wantClients)
}

// Test_singleAgent registers one agent and checks that it is reported. It
// then checks that the agent disappears from the state once its connection
// is closed.
func (s *MatchMakerTestSuite) Test_singleAgent() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)
	waitForAgentFunc := s.registerAgent(publicId, agent)
	// Starting a nil func with `go` panics outside the test goroutine and
	// aborts the whole test binary, so stop this test cleanly instead.
	s.Require().NotNil(waitForAgentFunc)

	s.checkState(1, 0)

	// required for connection loss detection
	go waitForAgentFunc()

	agent.Disconnect()
	s.checkState(0, 0)
}

// Test_singleAgentAndClient connects one client to one agent and verifies
// that data flows both ways. It then verifies that disconnecting the client
// removes it from the state while the agent stays registered.
func (s *MatchMakerTestSuite) Test_singleAgentAndClient() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)

	waitForAgentFunc := s.registerAgent(publicId, agent)
	// Starting a nil func with `go` panics outside the test goroutine and
	// aborts the whole test binary, so stop this test cleanly instead.
	s.Require().NotNil(waitForAgentFunc)
	go waitForAgentFunc()

	client, err := s.connectClient(publicId, false)
	s.Require().Nil(err)

	s.checkState(1, 1)

	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
	// Don't run the connection check against a connection we never obtained.
	s.Require().Nil(err)
	log.Printf("Agent side conn %v", agentClientSideConn)
	testsupport.BidirectionalConnectionCheck(
		&s.Suite, "testmsg",
		client.clientSideConn,
		agentClientSideConn)

	client.Disconnect()
	// It is the agent's choice to exit. The test agent does not exit by
	// default when there are no more connections.
	s.checkState(1, 0)
}

// Test_ConnectCLientToUnknownAgent checks that a client connecting to a
// rendez-vous id with no registered agent is refused and leaves no state behind.
func (s *MatchMakerTestSuite) Test_ConnectCLientToUnknownAgent() {
	_, connectErr := s.connectClient(models.RendezVousId("abc"), false)
	s.NotNil(connectErr)
	s.checkState(0, 0)
}

// Test_multipleAgentsAndClients registers several agents and connects a
// different number of clients to each. It then checks, in parallel, that
// every client can exchange data with the agent it was routed to.
func (s *MatchMakerTestSuite) Test_multipleAgentsAndClients() {
	agents := []string{"abc", "def", "ghi"}
	clients := map[string]int{"abc": 3, "def": 2, "ghi": 5}
	log.Printf("agents %v, clients %v", agents, clients)

	testAgents := make(map[string]*TestAgent, len(agents))
	for _, publicId := range agents {
		agent := NewTestAgent(s.ctx)
		testAgents[publicId] = agent
		waitForAgentFunc := s.registerAgent(models.RendezVousId(publicId), agent)
		s.Require().NotNil(waitForAgentFunc)
		go waitForAgentFunc()
	}

	testClients := make([]*TestClient, 0)
	for publicId, nclients := range clients {
		for range nclients {
			client, err := s.connectClient(models.RendezVousId(publicId), false)
			s.Require().Nil(err)
			testClients = append(testClients, client)
		}
	}

	wg := sync.WaitGroup{}

	// bidirectional connection test
	for _, testClient := range testClients {
		wg.Add(1)
		go func() {
			defer wg.Done()
			agent := testAgents[string(testClient.publicId)]
			agentClientSideConn, err := agent.listener.GetConnection(string(testClient.clientId))
			// Require/FailNow may only be called from the test goroutine, so
			// record the failure and leave this goroutine explicitly. Otherwise
			// the check below would run on a connection we never obtained.
			if !s.Nil(err) {
				return
			}
			log.Printf("Testing bi-directional commication client %v agent %v", testClient.clientId, testClient.publicId)
			testsupport.BidirectionalConnectionCheck(
				&s.Suite, "testmsg"+string(testClient.clientId),
				testClient.clientSideConn,
				agentClientSideConn)
		}()
	}
	wg.Wait()
}

// connectClient connects a new TestClient to the agent registered under
// publicId. It runs the server side (MatchMaker.Connect) and the client side
// (the wsproxy handshake when wsproxyMode is set) through
// testsupport.RunAndWait.
//
// The client is always returned so callers can inspect it. The returned error
// is the server-side error if there is one, otherwise the client-side
// handshake error, otherwise nil. On success, publicId and clientId are
// filled in on the client.
func (s *MatchMakerTestSuite) connectClient(publicId models.RendezVousId, wsproxyMode bool) (*TestClient, error) {
	client := NewTestClient(s.ctx)
	var clientId models.ClientId
	// Client-side handshake error. Like clientId, it is written inside
	// RunAndWait and read only after RunAndWait has returned.
	var wsproxyErr error
	res := testsupport.RunAndWait(
		&s.Suite,
		func() any {
			//server
			clientIdCreated, synchronizer, err := s.matchMaker.Connect(wsproxyMode, publicId, client.serverSIdeConn)
			clientId = clientIdCreated
			if err == nil {
				log.Println("test: synchronizing streams.")
				go synchronizer()
			}
			return err
		},
		func() any {
			//client
			if wsproxyMode {
				// Previously ignored: a failed handshake must fail the connect.
				wsproxyErr = client.WsproxyInit()
			}
			return nil
		})

	if res[0] != nil {
		return client, res[0].(error)
	}
	if wsproxyErr != nil {
		return client, wsproxyErr
	}
	client.publicId = publicId
	client.clientId = clientId
	return client, nil
}

// checkState asserts, via testsupport.CheckCondition (presumably polling
// until s.ctx expires — confirm), that the last state published to the
// notifier reports nAgents agents and nClients clients.
//
// A nil state (nothing published yet) simply does not match, rather than
// causing a nil-pointer panic inside the condition.
func (s *MatchMakerTestSuite) checkState(nAgents int, nClients int) {
	s.True(testsupport.CheckCondition(s.ctx, func() bool {
		state := s.notifier.state
		return state != nil && len(state.Agents) == nAgents
	}))
	s.True(testsupport.CheckCondition(s.ctx, func() bool {
		state := s.notifier.state
		return state != nil && len(state.Clients) == nClients
	}))
}

// registerAgent registers agent under publicId. It runs the server side
// (MatchMaker.Register) and the agent side (Initialize + Register) through
// testsupport.RunAndWait. Failures on either side are recorded on the suite.
//
// It returns the server's WaitForAgentFunc, or nil when server-side
// registration failed. Callers must start the returned func for connection
// loss to be detected.
func (s *MatchMakerTestSuite) registerAgent(publicId models.RendezVousId, agent *TestAgent) WaitForAgentFunc {
	res := testsupport.RunAndWait(
		&s.Suite,
		func() any {
			// server side
			waitFunc, err := s.matchMaker.Register(publicId, agent.serverSIdeConn)
			if !s.Nil(err) {
				// Return an untyped nil: a nil WaitForAgentFunc boxed in `any`
				// would compare non-nil.
				return nil
			}
			log.Printf("MatchMaskerTest: Agent registered by server")
			return waitFunc
		},
		func() any {
			// agent side
			if _, err := agent.Initialize(s); !s.Nil(err) {
				return nil
			}
			if err := agent.Register(s); !s.Nil(err) {
				return nil
			}
			log.Println("MatchMakerTest: Agent connected to server")
			return nil
		})

	// Comma-ok yields nil for an untyped nil result as well as for a nil
	// WaitForAgentFunc.
	waitFunc, _ := res[0].(WaitForAgentFunc)
	return waitFunc
}

// Test_ConnectWsproxyMode connects a client in wsproxy mode. It verifies the
// handshake reported success and that data flows both ways, then checks that
// disconnecting the client removes it from the state.
func (s *MatchMakerTestSuite) Test_ConnectWsproxyMode() {
	publicId := models.RendezVousId("abc")
	agent := NewTestAgent(s.ctx)

	waitForAgentFunc := s.registerAgent(publicId, agent)
	// Starting a nil func with `go` panics outside the test goroutine and
	// aborts the whole test binary, so stop this test cleanly instead.
	s.Require().NotNil(waitForAgentFunc)
	go waitForAgentFunc()

	client, err := s.connectClient(publicId, true)
	s.Require().Nil(err)

	s.checkState(1, 1)

	agentClientSideConn, err := agent.listener.GetConnection(string(client.clientId))
	// Don't run the connection check against a connection we never obtained.
	s.Require().Nil(err)
	log.Printf("Agent side conn %v", agentClientSideConn)
	testsupport.BidirectionalConnectionCheck(
		&s.Suite, "testmsg",
		client.clientSideConn,
		agentClientSideConn)

	s.True(client.clientConnectionInfo.Ok)

	client.Disconnect()
	// It is the agent's choice to exit. The test agent does not exit by
	// default when there are no more connections.
	s.checkState(1, 0)
}

// Test_ConnectWsproxyModeAgentNotFound checks that a wsproxy-mode client
// connecting to an unknown rendez-vous id gets a "No agent found" error. It
// also checks that the client never sees a successful connection info.
func (s *MatchMakerTestSuite) Test_ConnectWsproxyModeAgentNotFound() {
	publicId := models.RendezVousId("abc")

	client, err := s.connectClient(publicId, true)
	// err.Error() below would panic on a nil error, so stop here instead.
	s.Require().NotNil(err)
	s.True(strings.Contains(err.Error(), "No agent found for rendez-vous id"))
	s.False(client.clientConnectionInfo.Ok)

	s.checkState(0, 0)
}