package comms

import (
	"encoding/gob"
	"fmt"
	"github.com/hashicorp/yamux"
	"io"
	"log"
	"net"
	"time"
)

// CommChannel bundles everything needed to exchange gob-encoded messages
// with the other side of an agent/converge-server connection.
//
// As populated by NewCommChannel: Session is the yamux session created on
// top of the underlying connection, Peer is the single stream opened (agent)
// or accepted (converge server) on that session, and Encoder/Decoder are gob
// codecs that write to and read from Peer.
type CommChannel struct {
	Peer    net.Conn
	Encoder *gob.Encoder
	Decoder *gob.Decoder
	Session *yamux.Session
}

// AgentListener is a receiver for the events an agent can send. Its methods
// mirror the three callbacks taken by ListenForAgentEvents.
//
// NOTE(review): nothing in this part of the file references the interface;
// confirm whether it is still used elsewhere.
type AgentListener interface {
	// AgentInfo is invoked with a received AgentInfo event.
	AgentInfo(agent AgentInfo)
	// SessionInfo is invoked with a received SessionInfo event.
	SessionInfo(session SessionInfo)
	// ExpiryTimeUpdate is invoked with a received ExpiryTimeUpdate event.
	ExpiryTimeUpdate(session ExpiryTimeUpdate)
}

// Role identifies which end of the connection a CommChannel is created for.
// NewCommChannel uses it to decide which side of the yamux session to take
// and whether to open or accept the control stream.
type Role int

const (
	// Agent is the side that runs the yamux server, opens the control
	// stream and sends periodic heartbeats.
	Agent Role = iota
	// ConvergeServer is the side that runs the yamux client and accepts
	// the control stream opened by the agent.
	ConvergeServer
)

// NewCommChannel sets up the control channel between an agent and the
// converge server on top of wsConn.
//
// A yamux session is layered over wsConn: the Agent role acts as the yamux
// server and opens the control stream, the ConvergeServer role acts as the
// yamux client and accepts it. The stream is stored in Peer and wrapped in a
// gob encoder/decoder pair. For the Agent role a background goroutine sends a
// HeartBeat every 10 seconds; it stops once the yamux session is closed.
//
// On failure a zero CommChannel and a non-nil error are returned. This
// includes an unknown role, which is reported as an error rather than a
// panic. If the session was created but the control stream could not be
// established, the session is closed again so its goroutines do not leak;
// closing a yamux session also closes the underlying wsConn.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	const heartbeatInterval = 10 * time.Second

	var (
		session *yamux.Session
		err     error
	)
	switch role {
	case Agent:
		session, err = yamux.Server(wsConn, nil)
	case ConvergeServer:
		session, err = yamux.Client(wsConn, nil)
	default:
		return CommChannel{}, fmt.Errorf("undefined role %d", role)
	}
	if err != nil {
		return CommChannel{}, err
	}

	// Communication between Agent and ConvergeServer.
	// Currently used only for communication from Agent to ConvergeServer.
	// role is known to be valid here, so anything that is not Agent is
	// ConvergeServer.
	var peer net.Conn
	if role == Agent {
		peer, err = session.Open()
	} else {
		peer, err = session.Accept()
	}
	if err != nil {
		// Best-effort cleanup: the stream error is what the caller needs to
		// see, a secondary error from Close would only obscure it.
		_ = session.Close()
		return CommChannel{}, err
	}
	log.Println("Communication channel between agent and converge server established")

	RegisterEventsWithGob()
	commChannel := CommChannel{
		Peer:    peer,
		Encoder: gob.NewEncoder(peer),
		Decoder: gob.NewDecoder(peer),
		Session: session,
	}

	// The heartbeat is only intended to keep the connection up.
	if role == Agent {
		go func() {
			ticker := time.NewTicker(heartbeatInterval)
			defer ticker.Stop()
			for {
				select {
				case <-session.CloseChan():
					// Session is gone; nothing left to keep alive.
					return
				case <-ticker.C:
					if err := Send(commChannel, HeartBeat{}); err != nil {
						log.Println("Sending heartbeat to server failed")
					}
				}
			}
		}()
	}

	return commChannel, nil
}

// Send wraps object in a ConvergeMessage and gob-encodes it onto
// commChannel.Encoder, delivering it to the other side of the channel.
// An encoding failure is logged and then returned to the caller; nil is
// returned on success.
func Send(commChannel CommChannel, object any) error {
	message := ConvergeMessage{Value: object}
	if encodeErr := commChannel.Encoder.Encode(message); encodeErr != nil {
		log.Printf("Encoding error %v", encodeErr)
		return encodeErr
	}
	return nil
}

// ListenForAgentEvents blocks while decoding ConvergeMessage values from
// channel.Decoder and dispatches each payload by its dynamic type:
// AgentInfo goes to agentInfo, SessionInfo to sessionInfo and
// ExpiryTimeUpdate to expiryTimeUpdate. HeartBeat messages are dropped and
// any other payload type is reported on standard output.
//
// The function returns, after logging the error, as soon as decoding fails.
func ListenForAgentEvents(channel CommChannel,
	agentInfo func(agent AgentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate)) {
	for {
		// A fresh message per iteration so no state carries over between decodes.
		message := ConvergeMessage{}
		if decodeErr := channel.Decoder.Decode(&message); decodeErr != nil {
			// TODO: find a cleaner solution; the channel needs to be closed
			// explicitly when the agent exits.
			log.Printf("Exiting agent listener %v", decodeErr)
			return
		}

		switch event := message.Value.(type) {
		case AgentInfo:
			agentInfo(event)
		case SessionInfo:
			sessionInfo(event)
		case ExpiryTimeUpdate:
			expiryTimeUpdate(event)
		case HeartBeat:
			// Ignored for now. Behavior for a missing heartbeat could be
			// added here, but the heartbeat is only intended to keep the
			// connection up.
		default:
			fmt.Printf("  Unknown type: %T\n", event)
		}
	}
}

func ListenForServerEvents(channel CommChannel,
	setUsernamePassword func(user UserPassword)) {
	for {
		var result ConvergeMessage
		err := channel.Decoder.Decode(&result)

		if err != nil {
			// TODO more clean solution, need to explicitly close when agent exits.
			log.Printf("Exiting agent listener %v", err)
			return
		}
		switch v := result.Value.(type) {

		case UserPassword:
			setUsernamePassword(v)

		default:
			fmt.Printf("  Unknown type: %T\n", v)
		}
	}
}