package comms

import (
	"bytes"
	"context"
	"git.wamblee.org/converge/pkg/testsupport"
	"github.com/stretchr/testify/suite"
	"go.uber.org/goleak"
	"io"
	"log"
	"math/rand"
	"net/http"
	"strings"
	"testing"
	"time"
)

// AgentServerTestSuite exercises the agent/server communication code over an
// in-memory connection created fresh for every test.
type AgentServerTestSuite struct {
	suite.Suite

	// Per-test context and its cancel function, created in SetupTest.
	ctx        context.Context
	cancelFunc context.CancelFunc

	// pprof HTTP server shared by the whole suite.
	pprofServer *http.Server

	// The two ends of the in-memory connection: agent side and server side.
	agentReadWriter  io.ReadWriteCloser
	serverReadWriter io.ReadWriteCloser
}

// SetupSuite starts the pprof server once for the whole suite and keeps a
// handle to it so TearDownSuite can stop it.
func (s *AgentServerTestSuite) SetupSuite() {
	server := testsupport.StartPprof("")
	s.pprofServer = server
}

// TearDownSuite stops the pprof server started in SetupSuite.
//
// A fresh, bounded context is used on purpose. s.ctx belongs to the last test
// and has already been cancelled by TearDownTest. Passing it would give
// StopPprof no time to shut down gracefully (assuming, as its signature
// suggests, that StopPprof honours the context).
func (s *AgentServerTestSuite) TearDownSuite() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	testsupport.StopPprof(ctx, s.pprofServer)
}

// SetupTest creates a per-test context with a 10 second timeout and a fresh
// in-memory connection whose two ends are handed to the agent and the server.
func (s *AgentServerTestSuite) SetupTest() {
	s.ctx, s.cancelFunc = testsupport.CreateTestContext(context.Background(), 10*time.Second)

	// net.Pipe would also work, but it is always synchronous. This in-memory
	// connection becomes asynchronous when a channel size > 0 is passed in (as
	// done below). It also respects the context, so it handles cancellation
	// much better than net.Pipe.
	pipe := testsupport.NewInmemoryConnection(s.ctx, "inmemory", 10)
	s.agentReadWriter = pipe.Front()
	s.serverReadWriter = pipe.Back()
}

// TearDownTest restores the package-level protocol versions, which individual
// tests may modify, and cancels the per-test context. It then asks goleak to
// fail the test if unexpected goroutines are still running.
func (s *AgentServerTestSuite) TearDownTest() {
	agentProtocolVersion, serverProtocolVersion = PROTOCOL_VERSION, PROTOCOL_VERSION
	s.cancelFunc()
	goleak.VerifyNone(s.T())
}

// TestAgentServerTestSuite is the `go test` entry point that runs every test
// method of AgentServerTestSuite.
func TestAgentServerTestSuite(t *testing.T) {
	suite.Run(t, new(AgentServerTestSuite))
}

// TestClientInfoEncodeDecode round-trips ClientInfo values through an
// in-memory buffer and checks that the client id survives unchanged.
func (s *AgentServerTestSuite) TestClientInfoEncodeDecode() {
	ids := []string{"abc", "djkdfadfha"}
	for _, id := range ids {
		var buf bytes.Buffer
		s.Nil(SendClientInfo(&buf, ClientInfo{ClientId: id}))

		decoded, err := ReceiveClientInfo(&buf)
		s.Nil(err)
		s.Equal(id, decoded.ClientId)
	}
}

// createCommChannel builds the agent and server CommChannels concurrently on
// the two ends of the suite's in-memory connection and returns them as
// (agent, server). RunAndWait is assumed to return the results in the same
// order as the functions passed to it.
func (s *AgentServerTestSuite) createCommChannel() (CommChannel, CommChannel) {
	commChannels := testsupport.RunAndWait(
		&s.Suite,
		func() any {
			log.Println("Agent initializing")
			commChannel, err := NewCommChannel(Agent, s.agentReadWriter)
			s.Nil(err)
			return commChannel
		},
		func() any {
			log.Println("Server initializing")
			commChannel, err := NewCommChannel(ConvergeServer, s.serverReadWriter)
			s.Nil(err)
			return commChannel
		},
	)

	// Stop the test here instead of indexing past the end of a short slice.
	s.Require().Len(commChannels, 2)
	agentCommChannel := commChannels[0].(CommChannel)
	serverCommChannel := commChannels[1].(CommChannel)
	return agentCommChannel, serverCommChannel
}

// TestNewCommChannel sets up a comm channel pair and verifies that a message
// sent on the agent's side channel arrives unchanged on the server's side.
func (s *AgentServerTestSuite) TestNewCommChannel() {
	// Setup Comm channel
	agentCommChannel, serverCommChannel := s.createCommChannel()

	// Verify the side channel is working by sending an object and checking
	// that exactly the same value is received.
	sent := ProtocolVersion{Version: 10}
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, sent)
			s.Nil(err)
			log.Printf("Sent one message %v", sent)
			return nil
		},
		func() any {
			received, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel)
			s.Nil(err)
			log.Printf("Received one message %v", received)
			// Compare the payload itself, not just the absence of an error.
			s.Equal(sent, received)
			return nil
		},
	)

	log.Printf("%v %v", agentCommChannel, serverCommChannel)
}

// Test_ConnectThroughYamux opens a stream from the agent's session, accepts it
// on the server's session, and checks that data flows in both directions.
func (s *AgentServerTestSuite) Test_ConnectThroughYamux() {
	agentCommChannel, serverCommChannel := s.createCommChannel()

	dataAgentToServer := "hello"
	dataServerToAgent := "bye"
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			conn, err := agentCommChannel.Session.OpenStream()
			// On error conn is unusable. Touching it would panic in this
			// goroutine and abort the whole test binary instead of failing.
			if !s.Nil(err) {
				return nil
			}
			testsupport.AssertWriteData(&s.Suite, dataAgentToServer, conn)
			testsupport.AssertReadData(&s.Suite, dataServerToAgent, conn)
			return nil
		},
		func() any {
			conn, err := serverCommChannel.Session.Accept()
			if !s.Nil(err) {
				return nil
			}
			testsupport.AssertReadData(&s.Suite, dataAgentToServer, conn)
			testsupport.AssertWriteData(&s.Suite, dataServerToAgent, conn)
			return nil
		})
}

// Test_Initialization runs the agent and server handshakes against each other.
// It checks that the agent receives the server's TestId and that the server
// receives the agent's shell.
func (s *AgentServerTestSuite) Test_Initialization() {
	serverTestId := rand.Int()
	const agentShell = "agentshell"

	agentSide := func() any {
		received, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo(agentShell))
		s.Nil(err)
		s.Equal(serverTestId, received.TestId)
		return nil
	}
	serverSide := func() any {
		received, err := ServerInitialization(s.serverReadWriter, ServerInfo{TestId: serverTestId})
		s.Nil(err)
		s.Equal(agentShell, received.Shell)
		return nil
	}

	testsupport.RunAndWait(&s.Suite, agentSide, serverSide)
}

// Test_InitializationProtocolVersionMismatch bumps the server's protocol
// version so it differs from the agent's. It checks that both handshakes fail
// with an error mentioning "protocol" (case-insensitively) and return zero
// values. TearDownTest restores the protocol version afterwards.
func (s *AgentServerTestSuite) Test_InitializationProtocolVersionMismatch() {
	serverProtocolVersion++

	// Only inspect the message when there is an error. Calling err.Error() on
	// a nil error would panic inside a RunAndWait goroutine.
	assertProtocolError := func(err error) {
		if s.Error(err) {
			s.Contains(strings.ToLower(err.Error()), "protocol")
		}
	}

	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			serverInfo, err := AgentInitialization(s.agentReadWriter,
				NewEnvironmentInfo("myshell"))
			assertProtocolError(err)
			s.Equal(ServerInfo{}, serverInfo)
			return nil
		},
		func() any {
			serverInfo := ServerInfo{TestId: 1000}
			environmentInfo, err := ServerInitialization(s.serverReadWriter, serverInfo)
			assertProtocolError(err)
			s.Equal(EnvironmentInfo{}, environmentInfo)
			return nil
		})
}

// Test_InitializationAgentConnectionClosed closes the agent end before the
// handshake and expects both sides' initialization to fail.
func (s *AgentServerTestSuite) Test_InitializationAgentConnectionClosed() {
	_ = s.agentReadWriter.Close() // the close result is irrelevant here
	s.checkInitializationFailure()
}

// Test_InitializationServerConnectionClosed closes the server end before the
// handshake and expects both sides' initialization to fail.
func (s *AgentServerTestSuite) Test_InitializationServerConnectionClosed() {
	_ = s.serverReadWriter.Close() // the close result is irrelevant here
	s.checkInitializationFailure()
}

// checkInitializationFailure runs the agent and server handshakes
// concurrently. It asserts that each returns an error together with the zero
// value of its result type, and returns RunAndWait's results.
func (s *AgentServerTestSuite) checkInitializationFailure() []any {
	agentSide := func() any {
		received, err := AgentInitialization(s.agentReadWriter, NewEnvironmentInfo("myshell"))
		s.NotNil(err)
		s.Equal(ServerInfo{}, received)
		return nil
	}
	serverSide := func() any {
		received, err := ServerInitialization(s.serverReadWriter, ServerInfo{TestId: 1000})
		s.NotNil(err)
		s.Equal(EnvironmentInfo{}, received)
		return nil
	}
	return testsupport.RunAndWait(&s.Suite, agentSide, serverSide)
}

// TODO:
// Add tests where the connection is closed from the agent and from the server: verify an error is returned.

// Test_ListenForAgentEvents sends a random sequence of agent events over a GOB
// channel. It verifies that ListenForAgentEvents on the server side invokes
// the matching callback for each event, in the order the events were sent.
func (s *AgentServerTestSuite) Test_ListenForAgentEvents() {
	// The index of each event in this slice is the code recorded for it; the
	// callbacks below must record the same codes.
	agentEvents := []any{
		NewEnvironmentInfo("myshell"),
		NewSessionInfo("1", "sftp"),
		NewExpiryTimeUpdate(time.Now().Add(1 * time.Minute)),
		HeartBeat{},
	}
	const nevents = 100
	eventTypesSent := make([]int, nevents)
	testsupport.RunAndWait(
		&s.Suite,
		func() any {
			channel := NewGOBChannel(s.agentReadWriter)
			for i := range nevents {
				ievent := rand.Intn(len(agentEvents))
				eventTypesSent[i] = ievent
				event := ConvergeMessage{
					Value: agentEvents[ievent],
				}
				err := SendWithTimeout[ConvergeMessage](channel, event)
				s.Nil(err)
			}
			// Closing the agent side is what lets the server's listen loop
			// return. Pending events will still be sent.
			s.agentReadWriter.Close()
			return nil
		},
		func() any {
			// Append instead of writing into a pre-sized, zero-filled slice.
			// Missing events then show up as a length mismatch rather than
			// being masked by zero entries (0 is a valid event code), and
			// extra events cannot panic with index out of range.
			eventTypesReceived := make([]int, 0, nevents)
			record := func(eventType int) {
				eventTypesReceived = append(eventTypesReceived, eventType)
			}
			channel := NewGOBChannel(s.serverReadWriter)
			ListenForAgentEvents(channel,
				func(EnvironmentInfo) { record(0) },
				func(SessionInfo) { record(1) },
				func(ExpiryTimeUpdate) { record(2) },
				func(HeartBeat) { record(3) },
			)
			s.Equal(eventTypesSent, eventTypesReceived)
			return nil
		})
}

// This is currently a no-op, so there is no need to test it.
//func (s *AgentServerTestSuite) Test_LIstenForServerEvents() {
//
//}