package comms

import (
	"fmt"
	"github.com/hashicorp/yamux"
	"io"
	"log"
	"time"
)

// MESSAGE_TIMEOUT is the ten-second message timeout duration exported by this package.
const MESSAGE_TIMEOUT time.Duration = 10 * time.Second

// CommChannel bundles the yamux session that multiplexes the agent <-> converge
// server connection with the side channel that NewCommChannel establishes on it.
type CommChannel struct {
	// SideChannel is a separate stream on the yamux session, outside of the SSH
	// session, used for GOB-encoded messages between agent and converge server.
	SideChannel GOBChannel
	// Session is the yamux session (server side on the agent, client side on the
	// converge server) created by NewCommChannel.
	Session *yamux.Session
}

// Role says which end of the agent <-> converge server connection the local
// process is playing.
type Role int

// Known roles, spelled out explicitly: Agent is 0 and ConvergeServer is 1.
const (
	Agent          Role = 0
	ConvergeServer Role = 1
)

// NewCommChannel sets up yamux multiplexing over wsConn and establishes the side
// channel used for messages between the agent and the converge server.
//
// The Agent acts as the yamux server and opens the side-channel stream; the
// ConvergeServer acts as the yamux client and accepts that stream. The side
// channel is currently only used for messages from the agent to the converge
// server.
//
// If the side channel cannot be established, the yamux session is closed before
// the error is returned so its resources are not leaked. An undefined role is
// reported as an error.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var commChannel CommChannel

	switch role {
	case Agent:
		session, err := yamux.Server(wsConn, nil)
		if err != nil {
			return CommChannel{}, fmt.Errorf("creating yamux server session: %w", err)
		}
		stream, err := session.OpenStream()
		if err != nil {
			// Best effort cleanup; the OpenStream error is the one worth reporting.
			_ = session.Close()
			return CommChannel{}, fmt.Errorf("opening side channel stream: %w", err)
		}
		commChannel = CommChannel{
			Session:     session,
			SideChannel: NewGOBChannel(stream),
		}
	case ConvergeServer:
		session, err := yamux.Client(wsConn, nil)
		if err != nil {
			return CommChannel{}, fmt.Errorf("creating yamux client session: %w", err)
		}
		conn, err := session.Accept()
		if err != nil {
			// Best effort cleanup; the Accept error is the one worth reporting.
			_ = session.Close()
			return CommChannel{}, fmt.Errorf("accepting side channel stream: %w", err)
		}
		commChannel = CommChannel{
			Session:     session,
			SideChannel: NewGOBChannel(conn),
		}
	default:
		return CommChannel{}, fmt.Errorf("undefined role %d", role)
	}

	log.Println("Communication channel between agent and converge server established")
	return commChannel, nil
}

// Communication from agent to server during the session.

// SetupHeartBeat starts a background goroutine that sends a HeartBeat message on
// commChannel's side channel every 10 seconds. The heartbeat is only meant to
// keep the connection up, so a failed send is logged and retried on the next
// tick. The goroutine returns once commChannel.Session has been closed, so it
// does not outlive the connection.
func SetupHeartBeat(commChannel CommChannel) {
	const heartbeatInterval = 10 * time.Second

	go func() {
		ticker := time.NewTicker(heartbeatInterval)
		defer ticker.Stop()

		for range ticker.C {
			// Once the multiplexed session is gone every further send would fail;
			// stop instead of logging a failure forever.
			if commChannel.Session != nil && commChannel.Session.IsClosed() {
				return
			}
			err := Send(commChannel.SideChannel,
				ConvergeMessage{
					Value: HeartBeat{},
				})
			if err != nil {
				log.Printf("Sending heartbeat to server failed: %v", err)
			}
		}
	}()
}

// ListenForAgentEvents decodes ConvergeMessage values from channel until decoding
// fails (e.g. the stream is closed) and dispatches each one by type:
//   - EnvironmentInfo  -> agentInfo
//   - SessionInfo      -> sessionInfo
//   - ExpiryTimeUpdate -> expiryTimeUpdate
//
// A nil callback means messages of that type are dropped. HeartBeat messages are
// accepted and ignored; any other type is logged.
func ListenForAgentEvents(channel GOBChannel,
	agentInfo func(agent EnvironmentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate)) {
	for {
		var result ConvergeMessage
		if err := channel.Decoder.Decode(&result); err != nil {
			log.Printf("Exiting agent listener %v", err)
			return
		}

		switch v := result.Value.(type) {
		case EnvironmentInfo:
			if agentInfo != nil {
				agentInfo(v)
			}
		case SessionInfo:
			if sessionInfo != nil {
				sessionInfo(v)
			}
		case ExpiryTimeUpdate:
			if expiryTimeUpdate != nil {
				expiryTimeUpdate(v)
			}
		case HeartBeat:
			// Ignored for now: the heartbeat is only intended to keep the connection
			// up. Behavior for missing heartbeats could be added here.
		default:
			log.Printf("Unknown type: %v %T", v, v)
		}
	}
}

// ListenForServerEvents decodes ConvergeMessage values sent by the converge server
// on channel's side channel until decoding fails (e.g. the stream is closed).
// No server events are supported yet, so every decoded message is only logged.
func ListenForServerEvents(channel CommChannel) {
	for {
		var result ConvergeMessage
		if err := channel.SideChannel.Decoder.Decode(&result); err != nil {
			log.Printf("Exiting server listener %v", err)
			return
		}

		// No supported server events at this time.
		log.Printf("Unknown type: %T", result.Value)
	}
}

// AgentInitialization runs the agent side of the initial handshake on conn. It
// checks the protocol version, sends agentInfo, and then waits for the server's
// ServerInfo. Any failure, including a protocol version mismatch, aborts the
// handshake and is returned as an error together with a zero ServerInfo.
func AgentInitialization(conn io.ReadWriter, agentInfo EnvironmentInfo) (ServerInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(Agent, channel); err != nil {
		return ServerInfo{}, fmt.Errorf("checking protocol version: %w", err)
	}

	if err := SendWithTimeout(channel, agentInfo); err != nil {
		return ServerInfo{}, fmt.Errorf("sending agent info: %w", err)
	}

	serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
	if err != nil {
		return ServerInfo{}, fmt.Errorf("receiving server info: %w", err)
	}
	// TODO remove logging
	log.Println("Agent configuration received from server")

	return serverInfo, nil
}

// ServerInitialization runs the converge server side of the initial handshake on
// conn. It checks the protocol version, receives the agent's EnvironmentInfo, and
// replies with serverInfo. Any failure, including a protocol version mismatch,
// aborts the handshake and is returned as an error together with a zero
// EnvironmentInfo.
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (EnvironmentInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(ConvergeServer, channel); err != nil {
		return EnvironmentInfo{}, fmt.Errorf("checking protocol version: %w", err)
	}

	agentInfo, err := ReceiveWithTimeout[EnvironmentInfo](channel)
	if err != nil {
		return EnvironmentInfo{}, fmt.Errorf("receiving agent info: %w", err)
	}
	log.Println("Agent info received: ", agentInfo)

	if err := SendWithTimeout(channel, serverInfo); err != nil {
		return EnvironmentInfo{}, fmt.Errorf("sending server info: %w", err)
	}
	return agentInfo, nil
}

// Events sent over the websocket connection that is established between
// agent and converge server. This is done as soon as the agent starts.

// First communication between the agent and the converge server.
// Both sides exchange their protocol versions, and if they do not match, the
// session is terminated.

// CheckProtocolVersion performs the protocol version exchange on channel.
//
// The Agent sends its PROTOCOL_VERSION first and then receives the server's. The
// ConvergeServer receives the agent's version first and then replies with its
// own. Each side returns an error if the peer's version differs from its own
// PROTOCOL_VERSION, or if sending or receiving fails. An undefined role is
// reported as an error.
func CheckProtocolVersion(role Role, channel GOBChannel) error {
	local := ProtocolVersion{Version: PROTOCOL_VERSION}

	switch role {
	case Agent:
		err := SendWithTimeout(channel, local)
		if err != nil {
			return err
		}
		server, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		if server.Version != PROTOCOL_VERSION {
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				PROTOCOL_VERSION, server.Version)
		}
		return nil
	case ConvergeServer:
		agent, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		// Reply before checking so that the agent receives our version and can
		// detect a mismatch on its side as well.
		err = SendWithTimeout(channel, local)
		if err != nil {
			return err
		}
		if agent.Version != PROTOCOL_VERSION {
			// On this side the peer is the agent, so the received version is the
			// agent's and the local one is the converge server's.
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				agent.Version, PROTOCOL_VERSION)
		}
		return nil
	default:
		return fmt.Errorf("unexpected role %v", role)
	}
}

// Session info metadata exchange. These messages are sent over the SSH connection. The
// agent's embedded SSH server listens for connections, but we provide a custom listener
// (AgentListener) that decorates the yamux Session (which is a listener) and uses each
// accepted connection to exchange some metadata before the connection is handed back to SSH.

// SendClientInfo wraps conn in a fresh GOBChannel and sends info over it with
// SendWithTimeout, returning whatever error that call reports.
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
	return SendWithTimeout(NewGOBChannel(conn), info)
}

// ReceiveClientInfo wraps conn in a fresh GOBChannel and reads one ClientInfo
// from it with ReceiveWithTimeout. On failure it returns a zero ClientInfo along
// with the error.
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
	received, err := ReceiveWithTimeout[ClientInfo](NewGOBChannel(conn))
	if err == nil {
		return received, nil
	}
	return ClientInfo{}, err
}

// Message sent on the initial connection from the server to the agent to confirm the registration.

// SendRegistrationMessage wraps conn in a fresh GOBChannel and sends registration
// over it with SendWithTimeout, returning whatever error that call reports.
func SendRegistrationMessage(conn io.ReadWriter, registration AgentRegistration) error {
	return SendWithTimeout(NewGOBChannel(conn), registration)
}

// ReceiveRegistrationMessage wraps conn in a fresh GOBChannel and reads one
// AgentRegistration from it with ReceiveWithTimeout. On failure it returns a zero
// AgentRegistration along with the error.
func ReceiveRegistrationMessage(conn io.ReadWriter) (AgentRegistration, error) {
	received, err := ReceiveWithTimeout[AgentRegistration](NewGOBChannel(conn))
	if err == nil {
		return received, nil
	}
	return AgentRegistration{}, err
}