package comms

import (
	"encoding/gob"
	"fmt"
	"github.com/hashicorp/yamux"
	"io"
	"log"
	"net"
	"os"
	"os/user"
	"time"
)

// CommChannel is the message channel between an agent and the converge
// server. NewCommChannel builds it on top of a yamux session, and values are
// exchanged over Peer with gob encoding.
type CommChannel struct {
	// Peer is the yamux stream established by NewCommChannel: opened by the
	// Agent side, accepted by the ConvergeServer side.
	Peer    net.Conn
	// Encoder writes gob-encoded values to Peer.
	Encoder *gob.Encoder
	// Decoder reads gob-encoded values from Peer.
	Decoder *gob.Decoder
	// Session is the yamux session multiplexed over the underlying
	// connection; Peer is one of its streams.
	Session *yamux.Session
}

// Role selects which end of the agent <-> converge server link a process
// plays when calling NewCommChannel.
type Role int

const (
	// Agent runs the yamux server side and opens the shared stream.
	Agent Role = 0
	// ConvergeServer runs the yamux client side and accepts the stream
	// opened by the agent.
	ConvergeServer Role = 1
)

// NewCommChannel wraps wsConn in a yamux session and establishes the single
// stream used for messages between an agent and the converge server.
//
// role selects which side of the handshake this process plays:
//   - Agent: acts as the yamux server and opens the stream.
//   - ConvergeServer: acts as the yamux client, accepts the stream opened by
//     the agent, and starts a goroutine (serverReader) that decodes and logs
//     incoming ConvergeMessage values.
//
// The returned CommChannel carries the session, the stream (Peer) and a gob
// Encoder/Decoder bound to that stream.
//
// An unknown role is reported as an error. If the stream cannot be
// established, the session is closed before returning so it does not leak;
// closing a yamux session also closes wsConn.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var (
		session *yamux.Session
		err     error
	)
	switch role {
	case Agent:
		session, err = yamux.Server(wsConn, nil)
	case ConvergeServer:
		session, err = yamux.Client(wsConn, nil)
	default:
		return CommChannel{}, fmt.Errorf("undefined role %d", role)
	}
	if err != nil {
		return CommChannel{}, fmt.Errorf("creating yamux session: %w", err)
	}

	// The stream carries communication between Agent and ConvergeServer.
	// Currently it is only used for messages from Agent to ConvergeServer.
	var peer net.Conn
	if role == Agent {
		peer, err = session.OpenStream()
	} else { // ConvergeServer; any other role was rejected above.
		peer, err = session.Accept()
	}
	if err != nil {
		// Best effort: the stream error is what the caller needs to see.
		_ = session.Close()
		return CommChannel{}, fmt.Errorf("establishing stream: %w", err)
	}
	log.Println("Communication channel between agent and converge server established")

	// gob needs the concrete types that travel inside ConvergeMessage.Value
	// to be registered. Re-registering the same type under the same name is
	// harmless, so doing this on every call is safe.
	gob.Register(SessionInfo{})
	gob.Register(ExpiryTimeUpdate{})
	gob.Register(ConvergeMessage{})

	commChannel := CommChannel{
		Peer:    peer,
		Encoder: gob.NewEncoder(peer),
		Decoder: gob.NewDecoder(peer),
		Session: session,
	}

	if role == ConvergeServer {
		go serverReader(commChannel)
	}
	return commChannel, nil
}

// serverReader decodes ConvergeMessage values from channel.Decoder and logs
// each one according to the concrete type of its Value. It runs until
// decoding fails, for example when the agent closes the stream.
//
// On exit the Peer stream is closed. Without this, the agent's writes could
// block once the stream's unread data fills the yamux receive window. The
// yamux session itself is left open, since it is multiplexed and may carry
// other streams.
func serverReader(channel CommChannel) {
	// TODO: a cleaner shutdown is needed; the agent should explicitly close
	// the stream when it exits.
	defer func() {
		// Best effort: the stream is already unusable at this point.
		_ = channel.Peer.Close()
	}()
	for {
		var result ConvergeMessage
		if err := channel.Decoder.Decode(&result); err != nil {
			log.Printf("Exiting serverReader %v", err)
			return
		}
		switch v := result.Value.(type) {
		case SessionInfo:
			log.Println("RECEIVED: session info ", v)
		case ExpiryTimeUpdate:
			log.Println("RECEIVED: expirytime update ", v)
		default:
			// Route through the same logger as the other cases rather than
			// writing to stdout.
			log.Printf("RECEIVED: unknown type: %T", v)
		}
	}
}

// SessionInfo describes the environment of the process that created it. It
// is one of the payload types carried in ConvergeMessage.Value.
type SessionInfo struct {
	Username string
	Hostname string
	Pwd      string
}

// NewSessionInfo collects the current OS user name, the host name and the
// working directory of this process.
//
// Every lookup is best-effort: a failed lookup leaves its field empty
// instead of aborting. In particular, a failing user.Current no longer
// dereferences a nil *user.User.
func NewSessionInfo() SessionInfo {
	var info SessionInfo
	if u, err := user.Current(); err == nil {
		info.Username = u.Username
	}
	// Errors are ignored deliberately, matching the best-effort user lookup
	// above.
	info.Hostname, _ = os.Hostname()
	info.Pwd, _ = os.Getwd()
	return info
}

// ExpiryTimeUpdate carries an updated expiry time. It is one of the payload
// types carried in ConvergeMessage.Value.
type ExpiryTimeUpdate struct {
	ExpiryTime time.Time
}

// NewExpiryTimeUpdate returns an ExpiryTimeUpdate holding expiryTime.
func NewExpiryTimeUpdate(expiryTime time.Time) ExpiryTimeUpdate {
	var update ExpiryTimeUpdate
	update.ExpiryTime = expiryTime
	return update
}

// ConvergeMessage is the envelope exchanged over a CommChannel. Value holds
// one of the concrete payload types (SessionInfo or ExpiryTimeUpdate). gob
// can only transmit a concrete type stored in an interface field if that type
// has been registered with gob.Register.
type ConvergeMessage struct {
	Value interface{}
}