package main

import (
	"converge/pkg/comms"
	"converge/pkg/support/iowrappers"
	"converge/pkg/support/websocketutil"
	"crypto/tls"
	"fmt"
	"github.com/gorilla/websocket"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"
)

// closeConnection closes conn. When conn is a *net.TCPConn its linger time is
// first set to zero. Errors from both SetLinger and Close are ignored: this is
// a best-effort teardown.
func closeConnection(conn net.Conn) {
	tcp, isTCP := conn.(*net.TCPConn)
	if isTCP {
		// Best effort: a failure to change the linger setting must not
		// prevent the close below.
		_ = tcp.SetLinger(0)
	}
	_ = conn.Close()
}

// Stdio presents the process's standard input and standard output as a single
// value with Read and Write methods. It carries no state.
type Stdio struct{}

// Read reads from os.Stdin into p and returns what os.Stdin.Read returns.
func (Stdio) Read(p []byte) (int, error) {
	count, readErr := os.Stdin.Read(p)
	return count, readErr
}

// Write writes p to os.Stdout and returns what os.Stdout.Write returns.
func (Stdio) Write(p []byte) (int, error) {
	count, writeErr := os.Stdout.Write(p)
	return count, writeErr
}

// getArg extracts the argument of the option found at args[0].
//
// It returns the option's argument (args[1]) together with args[1:], i.e. the
// remaining slice still starts at the argument itself, so a caller that drops
// one element per processed option ends up past the argument.
//
// When the option or its argument is missing, printHelp is called to report
// the problem; printHelp terminates the process.
func getArg(args []string) (value string, ret []string) {
	switch len(args) {
	case 0:
		// Nothing to name in the message; indexing args[0] here would panic.
		printHelp("Expected an option followed by an argument")
	case 1:
		printHelp(fmt.Sprintf("The '%s' option expects an argument", args[0]))
	}
	return args[1], args[1:]
}

// printHelp writes the usage text to standard error and terminates the process
// with exit status 1. It never returns.
//
// When msg is non-empty it is printed first as an "ERROR: ..." line so the user
// sees what was wrong with the invocation before the usage text.
func printHelp(msg string) {
	if msg != "" {
		fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg)
	}
	// The synopsis lists exactly the options that the argument parser accepts.
	usage := "Usage: wsproxy [--insecure] [--raw] ws[s]://<host>[:port]/client/<ID>\n" +
		"\n" +
		"Here <ID> is the rendez-vous id of a continuous integration job\n" +
		"\n" +
		"Use this in an ssh command like this: \n" +
		"\n" +
		"  ssh -oServerAliveInterval=10 -oProxyCommand='wsproxy ws[s]://<host>[:port]/client/<ID>' abc@localhost\n" +
		"\n" +
		"The above ssh command connects to the converge server listening on the websocket through \n" +
		"wsproxy. \n" +
		"\n" +
		"Options:\n" +
		"\n" +
		"--insecure:     Skip certificate validation when used over a secure connection.\n" +
		"--raw:          Just use wsproxy as a raw stdio to websocket proxy, this disables messages\n" +
		"                from the server so the user will not get clear error messages.\n"
	fmt.Fprintln(os.Stderr, usage)
	os.Exit(1)
}

// main runs the proxy and exits with the status runProxy reports.
//
// All work happens in runProxy so that its deferred cleanup runs before the
// process exits; os.Exit does not run deferred functions.
func main() {
	os.Exit(runProxy(os.Args[1:]))
}

// runProxy implements the wsproxy command line tool.
//
// args are the command line arguments without the program name: any number of
// "--insecure" / "--raw" options followed by exactly one websocket URL. Invalid
// usage is reported through printHelp, which terminates the process.
//
// Unless --raw is given, "?wsproxy" is appended to the URL and, after
// connecting, a protocol version and a client connection info message are
// read from the server and checked before any data is relayed.
//
// It returns the process exit status: 0 once the stream synchronization has
// returned, 1 when connecting or the initial exchange with the server fails.
func runProxy(args []string) int {
	insecure := false
	raw := false

	// Consume leading options; the first non-option argument ends the loop.
	for len(args) > 0 && strings.HasPrefix(args[0], "-") {
		switch args[0] {
		case "--insecure":
			insecure = true
		case "--raw":
			raw = true
		default:
			printHelp("Unknown option " + args[0])
		}
		args = args[1:]
	}

	if len(args) != 1 {
		printHelp("")
	}

	wsURL := args[0]
	urlParsed, err := url.Parse(wsURL)
	if err != nil {
		printHelp(fmt.Sprintf("Url '%s' is not valid", wsURL))
	}
	if !raw {
		// The query string is reserved for the "wsproxy" marker appended
		// below, so a user supplied query is rejected.
		if len(urlParsed.Query()) > 0 {
			printHelp("When not used in raw mode, wsproxy does not accept query parameters in the URL")
		}
		wsURL += "?wsproxy"
	}

	dialer := websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
	}
	if insecure {
		dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}
	conn, resp, err := dialer.Dial(wsURL, nil)
	if err != nil {
		// On a failed handshake the dialer also hands back the HTTP
		// response; its status usually explains the failure.
		if resp != nil {
			log.Printf("WebSocket connection error: %v (HTTP status %s)", err, resp.Status)
		} else {
			log.Println("WebSocket connection error:", err)
		}
		return 1
	}
	wsConn := websocketutil.NewWebSocketConn(conn, false)
	defer wsConn.Close()

	if !raw {
		channel := comms.NewGOBChannel(wsConn)
		// receive protocol version
		protocolVersion, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](channel)
		if err != nil {
			log.Printf("Error receiving protocol version %v", err)
			return 1
		}
		if protocolVersion.Version != comms.PROTOCOL_VERSION {
			log.Printf("Protocol version mismatch: client %d, server %d",
				comms.PROTOCOL_VERSION, protocolVersion.Version)
			return 1
		}
		// receive confirmation about the agent id.
		clientConnectionInfo, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](channel)
		if err != nil {
			log.Printf("Error receiving client connection info: %v", err)
			return 1
		}
		if !clientConnectionInfo.Ok {
			log.Printf("Error reported by server: %v", clientConnectionInfo.Message)
			return 1
		}
		log.Printf("Client connection was accepted, agent is available.")
	}

	// Hand the websocket connection and stdio over for the data phase.
	iowrappers.SynchronizeStreams(wsURL+" -- stdio", wsConn, Stdio{})
	return 0
}