package comms

import (
	"converge/pkg/support/websocketutil"
	"net"
	"strconv"
)

// AgentListener wraps a net.Listener. Its Accept method reads client info
// from each newly accepted connection via ReceiveClientInfo. It then hands
// back a connection whose LocalAddr reports the client's ID. Close and Addr
// delegate directly to the wrapped listener.
type AgentListener struct {
	// decorated is the underlying listener that actually accepts connections.
	decorated net.Listener
}

// NewAgentListener returns an AgentListener that decorates listener.
func NewAgentListener(listener net.Listener) AgentListener {
	var wrapped AgentListener
	wrapped.decorated = listener
	return wrapped
}

// LocalAddrHackConn wraps a net.Conn and overrides only LocalAddr. Every
// other net.Conn method is promoted from the embedded Conn.
type LocalAddrHackConn struct {
	net.Conn
	// localAddr is what LocalAddr reports instead of the embedded Conn's
	// own local address.
	localAddr net.Addr
}

// LocalAddr returns the substituted address stored in the wrapper rather
// than the address reported by the underlying connection.
func (c LocalAddrHackConn) LocalAddr() net.Addr {
	substituted := c.localAddr
	return substituted
}

// NewLocalAddrHackConn wraps conn so that LocalAddr reports
// websocketutil.WebSocketAddr(clientId). All other net.Conn methods are
// served by conn itself through embedding.
func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
	return LocalAddrHackConn{
		Conn:      conn,
		localAddr: websocketutil.WebSocketAddr(clientId),
	}
}

// Accept waits for the next connection on the wrapped listener and reads the
// client's info from it with ReceiveClientInfo. On success it returns the
// connection wrapped in a LocalAddrHackConn. That wrapper's LocalAddr
// reports a websocketutil.WebSocketAddr built from the decimal client ID.
// If the handshake fails, the raw connection is closed and the handshake
// error is returned. Errors from the wrapped listener are returned unchanged.
//
// NOTE(review): the handshake runs synchronously inside Accept, so a client
// that never sends its info can stall later Accept calls. A handshake error
// also surfaces as an Accept error; http.Server.Serve, for example, stops
// serving on a non-temporary Accept error. Confirm the callers tolerate both.
func (listener AgentListener) Accept() (net.Conn, error) {
	raw, acceptErr := listener.decorated.Accept()
	if acceptErr != nil {
		return nil, acceptErr
	}

	info, handshakeErr := ReceiveClientInfo(raw)
	if handshakeErr != nil {
		// Best-effort cleanup. The caller needs the handshake error, so a
		// failure to close is deliberately ignored.
		_ = raw.Close()
		return nil, handshakeErr
	}

	clientID := strconv.Itoa(info.ClientId)
	return NewLocalAddrHackConn(raw, clientID), nil
}

// Close closes the wrapped listener and reports its error, if any.
func (listener AgentListener) Close() error {
	inner := listener.decorated
	return inner.Close()
}

// Addr returns the network address of the wrapped listener.
func (listener AgentListener) Addr() net.Addr {
	addr := listener.decorated.Addr()
	return addr
}