package main

import (
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)

// main is the CLI entry point. It expects <user> and <hostname:port>
// arguments, prompts for the password without echoing it, opens a TCP
// connection to the server and runs an interactive SSH session over it.
// Any setup failure terminates the process via log.Fatalf.
func main() {
	if len(os.Args) != 3 {
		// Usage errors go to stderr so they do not mix with piped stdout.
		fmt.Fprintln(os.Stderr, "Usage: go run ssh_client.go <user> <hostname:port>")
		os.Exit(1)
	}

	user := os.Args[1]
	address := os.Args[2]

	fmt.Print("Enter Password: ")
	password, err := term.ReadPassword(int(os.Stdin.Fd()))
	// Echo is off while the password is read, so the user's Enter key never
	// reaches the screen; end the prompt line ourselves (also before any
	// error message, so it is not glued to the prompt).
	fmt.Println()
	if err != nil {
		log.Fatalf("Failed to read password: %v", err)
	}

	// Bound the TCP connect: plain net.Dial relies on the OS connect timeout,
	// which can be very long for an unreachable or filtered host.
	const dialTimeout = 10 * time.Second
	netConn, err := net.DialTimeout("tcp", address, dialTimeout)
	if err != nil {
		log.Fatalf("Failed to connect to %s: %v", address, err)
	}

	// Use the connection to create an SSH client.
	if err := runSSHClient(user, string(password), netConn); err != nil {
		log.Fatalf("SSH session failed: %v", err)
	}
}

// runSSHClient authenticates as user with password over the already-dialed
// netConn. It then opens an interactive shell on a remote pseudo-terminal and
// wires that shell to the local terminal until the remote shell exits.
//
// The local terminal is put into raw mode for the duration of the session and
// is always restored before this function returns or exits. If the remote
// shell exits unsuccessfully, the process exits with the remote status, after
// cleanup has run. Any other failure is returned as an error.
//
// SECURITY NOTE(review): the server's host key is NOT verified
// (ssh.InsecureIgnoreHostKey), so the connection is open to
// man-in-the-middle attacks. Consider verifying against known_hosts.
func runSSHClient(user, password string, netConn net.Conn) error {
	exitStatus, err := runSSHShell(user, password, netConn)
	if err != nil {
		return err
	}
	if exitStatus != 0 {
		// os.Exit skips deferred calls, so it must only run here, after
		// runSSHShell has restored the terminal and closed the session.
		os.Exit(exitStatus)
	}
	return nil
}

// runSSHShell does the work of runSSHClient. It returns the remote shell's
// exit status (0 on success) instead of exiting the process, so that all of
// its deferred cleanup (signal handling, terminal restore, session/client
// close) runs before the caller decides how to exit.
func runSSHShell(user, password string, netConn net.Conn) (int, error) {
	// Upper bound on the SSH handshake. Without it, a server that accepts the
	// TCP connection but never completes the handshake would hang the client.
	const handshakeTimeout = 15 * time.Second

	config := &ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.Password(password),
		},
		// SECURITY: accepts any host key; see the note on runSSHClient.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	if err := netConn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		netConn.Close() // best effort; the deadline error is what we report
		return 0, fmt.Errorf("failed to set handshake deadline: %w", err)
	}

	// Create SSH connection from net.Conn.
	conn, chans, reqs, err := ssh.NewClientConn(netConn, netConn.RemoteAddr().String(), config)
	if err != nil {
		return 0, fmt.Errorf("failed to create SSH client connection: %w", err)
	}

	// Handshake finished: clear the deadline so the interactive session may
	// sit idle indefinitely.
	if err := netConn.SetDeadline(time.Time{}); err != nil {
		conn.Close() // best effort; the deadline error is what we report
		return 0, fmt.Errorf("failed to clear handshake deadline: %w", err)
	}

	client := ssh.NewClient(conn, chans, reqs)
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return 0, fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()

	modes := ssh.TerminalModes{
		ssh.ECHO:          1,     // remote pty echoes typed characters
		ssh.TTY_OP_ISPEED: 14400, // input baud rate
		ssh.TTY_OP_OSPEED: 14400, // output baud rate
	}

	// Raw mode passes every keystroke straight through to the remote pty.
	fd := int(os.Stdin.Fd())
	oldState, err := term.MakeRaw(fd)
	if err != nil {
		return 0, fmt.Errorf("failed to set raw mode: %w", err)
	}
	defer term.Restore(fd, oldState)

	width, height, err := term.GetSize(fd)
	if err != nil {
		return 0, fmt.Errorf("failed to get terminal size: %w", err)
	}

	// RequestPty takes rows (height) before columns (width).
	if err := session.RequestPty("xterm", height, width, modes); err != nil {
		return 0, fmt.Errorf("request for pseudo terminal failed: %w", err)
	}

	session.Stdout = os.Stdout
	session.Stderr = os.Stderr
	session.Stdin = os.Stdin

	if err := session.Shell(); err != nil {
		return 0, fmt.Errorf("failed to start shell: %w", err)
	}

	// Handle window size changes.
	go handleWindowChange(session, fd)

	// Forward SIGINT delivered to this process to the remote shell. In raw
	// mode a local Ctrl+C is typically sent as a byte to the remote pty rather
	// than raised as a local signal, so this mainly covers SIGINT from other
	// sources (e.g. kill -INT).
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT)
	defer func() {
		// After Stop returns no more signals are sent on sigChan, so closing
		// it is safe and lets the forwarding goroutine below exit.
		signal.Stop(sigChan)
		close(sigChan)
	}()
	go func() {
		for range sigChan {
			// Best effort: servers may ignore or reject signal requests.
			_ = session.Signal(ssh.SIGINT)
		}
	}()

	if err := session.Wait(); err != nil {
		if exitErr, ok := err.(*ssh.ExitError); ok {
			// The remote shell ran but completed unsuccessfully; report its
			// status so the caller can exit with it after cleanup.
			return exitErr.ExitStatus(), nil
		}
		return 0, fmt.Errorf("failed to wait for session: %w", err)
	}

	return 0, nil
}