package main

import (
	"bufio"
	"crypto/tls"
	"fmt"
	"git.wamblee.org/converge/pkg/agent/session"
	"git.wamblee.org/converge/pkg/agent/terminal"
	"git.wamblee.org/converge/pkg/comms"
	"git.wamblee.org/converge/pkg/support/iowrappers"
	"git.wamblee.org/converge/pkg/support/websocketutil"
	"github.com/gliderlabs/ssh"
	"github.com/gorilla/websocket"
	"github.com/pkg/sftp"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"time"

	_ "embed"
	_ "net/http/pprof"
)

// SftpHandler is the SSH subsystem handler for "sftp". It registers the
// session, serves SFTP requests on it until the client disconnects or an
// error occurs, and unregisters the session again on return.
func SftpHandler(sftpSession ssh.Session) {
	info := comms.NewSessionInfo(sftpSession.LocalAddr().String(), "sftp")
	session.Login(info, sftpSession)
	defer session.LogOut(info.ClientId)

	// Debug output of the sftp server is thrown away.
	options := []sftp.ServerOption{sftp.WithDebug(io.Discard)}

	// Unlike ssh sessions, where activity is derived from user input, sftp
	// activity is measured on data the server sends to the client. That way a
	// long-running download keeps counting as activity and is not cut off by
	// a timeout before it finishes.
	detector := NewSftpActivityDetector(sftpSession)

	server, err := sftp.NewServer(detector, options...)
	if err != nil {
		log.Printf("sftp tcpserver init error: %s\n", err)
		return
	}

	switch err := server.Serve(); {
	case err == io.EOF:
		server.Close()
		fmt.Println("sftp client exited session.")
	case err != nil:
		fmt.Println("sftp tcpserver completed with error:", err)
	}
}

// sshServer registers the interactive-shell session handler and returns an
// SSH server that authenticates clients against authorizedPublicKeys, offers
// the "sftp" subsystem and identifies itself with the PEM-encoded
// hostPrivateKey. shellCommand is the shell started for every interactive
// session.
func sshServer(hostPrivateKey []byte, shellCommand string,
	authorizedPublicKeys *AuthorizedPublicKeys) *ssh.Server {
	ssh.Handle(func(sshSession ssh.Session) {
		workingDirectory, err := os.Getwd()
		if err != nil {
			// Not fatal: the shell still starts, only agentdir will be empty.
			log.Printf("cannot determine agent working directory: %v", err)
		}
		env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
		process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand)
		if err != nil {
			// A shell that fails to start must only end this session; a panic
			// here would take down the whole agent and all other sessions.
			log.Printf("ERROR: cannot start shell %q for ssh session: %v", shellCommand, err)
			_ = sshSession.Exit(1)
			return
		}
		sessionInfo := comms.NewSessionInfo(
			sshSession.LocalAddr().String(), "ssh",
		)
		session.Login(sessionInfo, sshSession)
		// For SSH we detect activity when there are writes to the shell that
		// was spawned. This means user input.
		activityDetector := NewWriteDetector(process.Pipe())
		iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
		session.LogOut(sessionInfo.ClientId)
		// Waiting for the process may cause additional goroutines to remain
		// alive when the SSH session is killed. For now that is acceptable
		// since the agent is a short-lived process. Using Kill() here would
		// create defunct processes, and in normal circumstances Wait() is the
		// best choice because the process shuts down automatically once it has
		// lost its terminal.
		process.Wait()
	})

	log.Println("starting ssh server, waiting for debug sessions")

	server := ssh.Server{
		PublicKeyHandler: authorizedPublicKeys.authorize,
		SubsystemHandlers: map[string]ssh.SubsystemHandler{
			"sftp": SftpHandler,
		},
	}
	// Install the host key received for this agent on the server.
	option := ssh.HostKeyPEM(hostPrivateKey)
	option(&server)

	return &server
}

// echoServer writes everything it reads from conn back to conn until the
// read side reports end of input or an error occurs. A copy error is logged
// instead of being silently dropped.
func echoServer(conn io.ReadWriter) {
	log.Println("Echo service started")
	if _, err := io.Copy(conn, conn); err != nil {
		log.Printf("Echo service stopped with error: %v", err)
	}
}

// netCatServer connects conn to the standard input and output of the agent
// process.
func netCatServer(conn io.ReadWriter) {
	// Output goes to os.Stdout directly. bufio.NewWriterSize with a size of 0
	// falls back to the default buffer size, and nothing visible here flushes
	// that buffer, so short output would never reach stdout.
	stdio := struct {
		io.Reader
		io.Writer
	}{
		Reader: bufio.NewReaderSize(os.Stdin, 0),
		Writer: os.Stdout,
	}
	iowrappers.SynchronizeStreams("stdio -- ws", conn, stdio)
}

// AgentService is a service that the agent can run on top of a listener.
type AgentService interface {
	// Run serves the service on the given listener.
	Run(listener net.Listener)
}

// ListenerServer is an AgentService backed by a factory that creates the SSH
// server to run.
type ListenerServer func() *ssh.Server

// Run creates the SSH server and serves it on listener. It blocks until the
// server stops; the error Serve returns is logged instead of being dropped.
func (server ListenerServer) Run(listener net.Listener) {
	if err := server().Serve(listener); err != nil {
		log.Printf("ssh server stopped: %v", err)
	}
}

// ConnectionServer is an AgentService that handles every accepted connection
// with the wrapped function.
type ConnectionServer func(conn io.ReadWriter)

// Run accepts connections from listener and handles each of them in its own
// goroutine. The connection is closed once its handler returns, since the
// handler only sees an io.ReadWriter and cannot close it itself. Run returns
// when Accept fails, for example because the listener was closed.
func (server ConnectionServer) Run(listener net.Listener) {
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("accept failed, stopping connection server: %v", err)
			return
		}
		go func(c net.Conn) {
			defer c.Close()
			server(c)
		}(conn)
	}
}

// ReaderFunc adapts a plain function with the signature of Read so that it
// can be used wherever a value with a Read method is expected.
type ReaderFunc func(p []byte) (n int, err error)

// Read forwards the call to the wrapped function and returns its result
// unchanged.
func (read ReaderFunc) Read(buf []byte) (int, error) {
	return read(buf)
}

// validateString checks value against the regular expression pattern. When
// the value does not match, or the pattern itself cannot be compiled, it
// reports the problem through printHelp, using description to name the
// offending setting.
func validateString(value, description, pattern string) {
	if ok, err := regexp.MatchString(pattern, value); err == nil && ok {
		return
	}
	message := fmt.Sprintf("%s: wrong value '%s', must conform to pattern '%s'",
		description, value, pattern)
	printHelp(message)
}

// getId returns the rendez-vous id to use. An explicitly given id is
// validated and returned as is; for an empty id a numeric id is derived from
// the nanosecond part of the current time.
func getId(id string) string {
	if id != "" {
		validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`)
		return id
	}
	// No id was specified: generate one.
	generated := time.Now().Nanosecond() % 1000000000
	return strconv.Itoa(generated)
}

// printHelp writes the usage text to standard error, preceded by an ERROR
// line when msg is non-empty, and then terminates the process with exit
// status 1. It never returns.
func printHelp(msg string) {
	if msg != "" {
		fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg)
	}
	helpText := "agent [options] <wsUrl> \n" +
		"\n" +
		"Run agent with <wsUrl> of the form ws[s]://<host>[:port]\n" +
		"The rendez-vous id (see --id) is the unique id of the agent that allows rendez-vous with an end-user.\n" +
		"The end-user must specify the same id when connecting using ssh.\n" +
		"\n" +
		"--id:               rendez-vous id, this is the id used to connect agents and clients. \n" +
		"--authorized-keys:  SSH authorized keys file in openssh format. By default .authorized_keys in the\n" +
		"                    directory where the agent is started is used.\n" +
		"--warning-time:     advance warning time before session ends (default '5m')\n" +
		"--expiry-time:      expiry time of the session (default '10m')\n" +
		"--check-interval:   interval at which expiry is checked (default '1m')\n" +
		"--insecure:         allow invalid certificates\n" +
		"--shells:           comma-separated list of shells to add to the front of the shell search path\n" +
		"                    (e.g. 'zsh,sh'). If the shell name contains a slash, then the path must exist:\n" +
		"                    either relative to the agent's current directory or absolute. Otherwise it is looked\n" +
		"                    up in the system search path. "

	fmt.Fprintln(os.Stderr, helpText)
	os.Exit(1)
}

// getArg extracts the value of the option found at args[0]. It returns that
// value together with the argument list advanced by one position, so the
// returned slice starts at the value itself. A missing value is reported
// through printHelp.
func getArg(args []string) (value string, ret []string) {
	if len(args) <= 1 {
		message := fmt.Sprintf("The '%s' option expects an argument", args[0])
		printHelp(message)
	}
	value, ret = args[1], args[1:]
	return value, ret
}

// parseDuration reads the value of the option at args[0] and interprets it
// as a duration in time.ParseDuration syntax. It returns the duration and the
// remaining arguments as produced by getArg; a value that cannot be parsed is
// reported through printHelp.
func parseDuration(args []string) (time.Duration, []string) {
	text, rest := getArg(args)
	parsed, err := time.ParseDuration(text)
	if err != nil {
		printHelp(fmt.Sprintf("Error parsing duration: %v\n", err))
	}
	return parsed, rest
}

// main is the agent entry point. It parses the command-line options,
// connects to the server given by <wsUrl> over a websocket, registers the
// agent under its rendez-vous id, prints the commands clients can use to
// connect, and then runs the SSH service on top of that connection.
//
// When the PPROF_PORT environment variable is set, the default HTTP mux
// (which carries the net/http/pprof handlers) is served on localhost at that
// port.
func main() {

	pprofPort, ok := os.LookupEnv("PPROF_PORT")
	if ok {
		log.Printf("Enabling pprof on localhost:%s", pprofPort)
		go func() {
			log.Println(http.ListenAndServe("localhost:"+pprofPort, nil))
		}()
	}

	// Defaults, overridable through the command-line options below.
	id := ""
	authorizedKeysFile := ".authorized_keys"
	advanceWarningTime := 5 * time.Minute
	agentExpiryTime := 10 * time.Minute
	tickerInterval := 60 * time.Second
	insecure := false
	shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"}
	if runtime.GOOS == "windows" {
		shells = []string{"powershell", "bash"}
	}

	args := os.Args[1:]
	additionalShells := []string{}
	commaSeparated := ""
	for len(args) > 0 && strings.HasPrefix(args[0], "-") {
		switch args[0] {
		case "--id":
			id, args = getArg(args)
		case "--authorized-keys":
			authorizedKeysFile, args = getArg(args)
		case "--warning-time":
			advanceWarningTime, args = parseDuration(args)
		case "--expiry-time":
			agentExpiryTime, args = parseDuration(args)
		case "--check-interval":
			tickerInterval, args = parseDuration(args)
		case "--insecure":
			insecure = true
		case "--shells":
			commaSeparated, args = getArg(args)
			additionalShells = append(additionalShells, strings.Split(commaSeparated, ",")...)
		default:
			printHelp("Unknown option " + args[0])
		}
		// getArg leaves the option value at args[0]; this drops it (or the
		// flag itself for options without a value).
		args = args[1:]
	}

	if 2*advanceWarningTime > agentExpiryTime {
		printHelp("The warning time should be at most half the expiry time")
	}
	if 4*tickerInterval > agentExpiryTime {
		printHelp("The check interval should be at most 1/4 of the expiry time")
	}

	// Shells given on the command line take precedence over the defaults.
	shells = append(additionalShells, shells...)

	id = getId(id)

	if len(args) != 1 {
		printHelp("")
	}

	wsURL := args[0]
	// Named parsedURL so the net/url package is not shadowed for the rest of
	// the function.
	parsedURL, err := url.Parse(wsURL)
	if err != nil {
		printHelp(fmt.Sprintf("Invalid URL %s", wsURL))
	}
	wsURL += "/agent/" + id

	dialer := websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
	}
	if insecure {
		dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}

	// Authentication

	authorizedKeys, err := NewAuthorizedPublicKeys(authorizedKeysFile)
	if err != nil {
		log.Printf("ERROR: cannot load authorized keys from %s: %v", authorizedKeysFile, err)
		os.Exit(1)
	}

	// Connect to server

	conn, _, err := dialer.Dial(wsURL, nil)
	if err != nil {
		// Exit with a failure status, like every other startup error.
		log.Println("WebSocket connection error:", err)
		os.Exit(1)
	}
	wsConn := websocketutil.NewWebSocketConn(conn, false)
	defer wsConn.Close()

	shell := chooseShell(shells)
	_, err = comms.AgentInitialization(wsConn, comms.NewEnvironmentInfo(shell))
	if err != nil {
		log.Printf("ERROR: %v", err)
		os.Exit(1)
	}

	registration, err := comms.ReceiveRegistrationMessage(wsConn)
	if err != nil {
		log.Printf("ERROR: %v", err)
		os.Exit(1)
	}
	log.Println("Server responded with: ", registration.Message)

	if registration.Id != id {
		log.Println("==============================================================================")
		log.Println("Duplicate agent id detected: the server allocated a new id to be used instead.")
		log.Println("")
		log.Println(registration.Id)
		log.Println("==============================================================================")
	}
	clientURL := args[0] + "/client/" + registration.Id

	commChannel, err := comms.NewCommChannel(comms.Agent, wsConn)
	if err != nil {
		panic(err)
	}

	// initial check

	go comms.ListenForServerEvents(commChannel)

	var service AgentService = ListenerServer(func() *ssh.Server {
		return sshServer(registration.HostPrivateKey, shell, authorizedKeys)
	})
	//service = ConnectionServer(netCatServer)
	//service = ConnectionServer(echoServer)
	log.Println()
	log.Printf("Clients should use the following commands to connect to this agent:")
	log.Println()
	sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\"  localhost",
		clientURL)
	sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" localhost",
		clientURL)
	log.Println("  # For SSH")
	log.Println("  " + sshCommand)
	log.Println()
	log.Println("  # for SFTP")
	log.Println("  " + sftpCommand)
	log.Println()

	extension := ""
	if runtime.GOOS == "windows" {
		extension = ".exe"
	}
	// The download link points at the same host, with ws/wss mapped to
	// http/https.
	log.Printf("wsproxy can be downloaded from %s",
		strings.ReplaceAll(parsedURL.Scheme, "ws", "http")+
			"://"+parsedURL.Host+"/static/wsproxy"+extension)
	log.Println()

	session.ConfigureAgent(commChannel, advanceWarningTime, agentExpiryTime, tickerInterval)

	listener := comms.NewAgentListener(commChannel.Session)

	service.Run(listener)
}

// chooseShell returns the path of the first entry of shells that
// exec.LookPath resolves without error. When none of the candidates can be
// resolved it logs the failure and terminates the process with exit status 1.
func chooseShell(shells []string) string {
	log.Printf("Shell search path is %v", shells)
	for _, candidate := range shells {
		// Only a lookup without error counts. LookPath may return a non-empty
		// path together with an error (exec.ErrDot), and such a rejected
		// result must not leak out as the chosen shell.
		path, err := exec.LookPath(candidate)
		if err != nil {
			continue
		}
		log.Printf("Using shell %s for remote sessions", path)
		return path
	}
	log.Printf("Cannot find a shell in %v", shells)
	os.Exit(1)
	return "" // not reached; os.Exit does not return
}