package main import ( "bufio" "converge/pkg/agent/session" "converge/pkg/agent/terminal" "converge/pkg/comms" "converge/pkg/support/iowrappers" "converge/pkg/support/websocketutil" "crypto/tls" "fmt" "github.com/gliderlabs/ssh" "github.com/gorilla/websocket" "github.com/pkg/sftp" "io" "log" "net" "net/http" "net/url" "os" "os/exec" "regexp" "runtime" "strconv" "strings" "time" _ "embed" _ "net/http/pprof" ) //go:embed hostkey.pem var hostPrivateKey []byte func SftpHandler(sess ssh.Session) { sessionInfo := comms.NewSessionInfo( sess.LocalAddr().String(), "sftp", ) session.Login(sessionInfo, sess) defer session.LogOut(sessionInfo.ClientId) debugStream := io.Discard serverOptions := []sftp.ServerOption{ sftp.WithDebug(debugStream), } server, err := sftp.NewServer( sess, serverOptions..., ) if err != nil { log.Printf("sftp tcpserver init error: %s\n", err) return } if err := server.Serve(); err == io.EOF { server.Close() fmt.Println("sftp client exited session.") } else if err != nil { fmt.Println("sftp tcpserver completed with error:", err) } } type UserActivityDetector struct { session io.ReadWriter } func (user UserActivityDetector) Read(p []byte) (int, error) { n, err := user.session.Read(p) if err == nil && n > 0 { session.UserActivityDetected() } return n, err } func (user UserActivityDetector) Write(p []byte) (int, error) { return user.session.Write(p) } func sshServer(hostKeyFile string, shellCommand string, passwordHandler ssh.PasswordHandler, authorizedPublicKeys AuthorizedPublicKeys) *ssh.Server { ssh.Handle(func(sshSession ssh.Session) { workingDirectory, _ := os.Getwd() env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand) if err != nil { panic(err) } sessionInfo := comms.NewSessionInfo( sshSession.LocalAddr().String(), "ssh", ) session.Login(sessionInfo, sshSession) activityDetector := UserActivityDetector{ session: sshSession, } iowrappers.SynchronizeStreams("shell -- ssh", process.Pipe(), activityDetector) session.LogOut(sessionInfo.ClientId) // will cause addition goroutines to remmain alive when the SSH // session is killed. For now acceptable since the agent is a short-lived // process. Using Kill() here will create defunct processes and in normal // circummstances Wait() will be the best because the process will be shutting // down automatically becuase it has lost its terminal. process.Wait() }) log.Println("starting ssh server, waiting for debug sessions") server := ssh.Server{ PasswordHandler: passwordHandler, PublicKeyHandler: authorizedPublicKeys.authorize, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": SftpHandler, }, } option := ssh.HostKeyPEM(hostPrivateKey) option(&server) return &server } func echoServer(conn io.ReadWriter) { log.Println("Echo service started") io.Copy(conn, conn) } func netCatServer(conn io.ReadWriter) { stdio := bufio.NewReadWriter( bufio.NewReaderSize(os.Stdin, 0), bufio.NewWriterSize(os.Stdout, 0)) iowrappers.SynchronizeStreams("stdio -- ws", conn, stdio) } type AgentService interface { Run(listener net.Listener) } type ListenerServer func() *ssh.Server func (server ListenerServer) Run(listener net.Listener) { server().Serve(listener) } type ConnectionServer func(conn io.ReadWriter) func (server ConnectionServer) Run(listener net.Listener) { for { conn, err := listener.Accept() if err != nil { panic(err) } go server(conn) } } type ReaderFunc func(p []byte) (n int, err error) func (f ReaderFunc) Read(p []byte) (n int, err error) { return f(p) } func validateString(value, description, pattern string) { matched, err := regexp.MatchString(pattern, value) if err != nil || !matched { printHelp(fmt.Sprintf("%s: wrong value '%s', must conform to pattern '%s'", description, value, pattern)) } } func getId(id string) string { if id == "" { // not specified return strconv.Itoa(time.Now().Nanosecond() % 1000000000) } validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`) return id } func printHelp(msg string) { if msg != "" { fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg) } helpText := "agent [options] \n" + "\n" + "Run agent with of the form ws[s]://[:port]\n" + "Here is the unique id of the agent that allows rendez-vous with an end-user.\n" + "The end-user must specify the same id when connecting using ssh.\n" + "\n" + "--id: rendez-vous id. When specified an SSH authorized key must be used and password\n" + " based access is disabled. When not specified a random id is chosen by the agent and\n" + " password based access is possible. The password is configured on the converge server\n" + "--authorized-keys: SSH authorized keys file in openssh format. By default .authorized_keys in the\n" + " directory where the agent is started is used.\n" + "--warning-time: advance warning time before sessio ends (default '5m')\n" + "--expiry-time: expiry time of the session (default '10m')\n" + "--check-interval: interval at which expiry is checked\n" + "--insecure: allow invalid certificates\n" + "--shells: comma-separated list of shells to add to the front of theshell search path\n" + " (e.g. 'zsh,sh'). If the shell name contains a slash,then the path must exist:\n" + " either relative to the agent's current directory or absolute. Otherwise it is looekd\n" + " up in the system search path. " fmt.Fprintln(os.Stderr, helpText) os.Exit(1) } func getArg(args []string) (value string, ret []string) { if len(args) < 2 { printHelp(fmt.Sprintf("The '%s' option expects an argument", args[0])) } return args[1], args[1:] } func parseDuration(args []string) (time.Duration, []string) { arg, args := getArg(args) duration, err := time.ParseDuration(arg) if err != nil { printHelp(fmt.Sprintf("Error parsing duration: %v\n", err)) } return duration, args } func main() { pprofPort, ok := os.LookupEnv("PPROF_PORT") if ok { log.Printf("Enabllng pprof on localhost:%s", pprofPort) go func() { log.Println(http.ListenAndServe("localhost:"+pprofPort, nil)) }() } id := "" authorizedKeysFile := ".authorized_keys" advanceWarningTime := 5 * time.Minute agentExpriryTime := 10 * time.Minute tickerInterval := 60 * time.Second insecure := false shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} if runtime.GOOS == "windows" { shells = []string{"powershell", "bash"} } args := os.Args[1:] additionalShells := []string{} commaSeparated := "" for len(args) > 0 && strings.HasPrefix(args[0], "-") { switch args[0] { case "--id": id, args = getArg(args) case "--authorized-keys": authorizedKeysFile, args = getArg(args) case "--warning-time": advanceWarningTime, args = parseDuration(args) case "--expiry-time": agentExpriryTime, args = parseDuration(args) case "--check-interval": tickerInterval, args = parseDuration(args) case "--insecure": insecure = true case "--shells": commaSeparated, args = getArg(args) additionalShells = append(additionalShells, strings.Split(commaSeparated, ",")...) default: printHelp("Unknown option " + args[0]) } args = args[1:] } if 2*advanceWarningTime > agentExpriryTime { printHelp("The warning time should be at most half the expiry time") } if 4*tickerInterval > agentExpriryTime { printHelp("The check interval should be at most 1/4 of the agent interval") } shells = append(additionalShells, shells...) id = getId(id) if len(args) != 1 { printHelp("") } wsURL := args[0] url, err := url.Parse(wsURL) if err != nil { printHelp(fmt.Sprintf("Invalid URL %s", wsURL)) } wsURL += "/agent/" + id dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: 45 * time.Second, } if insecure { dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } conn, _, err := dialer.Dial(wsURL, nil) if err != nil { log.Println("WebSocket connection error:", err) return } wsConn := websocketutil.NewWebSocketConn(conn, false) defer wsConn.Close() shell := chooseShell(shells) serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo(shell)) if err != nil { log.Printf("ERROR: %v", err) os.Exit(1) } registration, err := comms.ReceiveRegistrationMessage(wsConn) if err != nil { log.Printf("ERROR: %v", err) os.Exit(1) } log.Println("Server responded with: ", registration.Message) if registration.Id != id { log.Println("==============================================================================") log.Println("Duplicate agent id detected: the server allocated a new id to be used instead.") log.Println("") log.Println(registration.Id) log.Println("==============================================================================") } clientUrl := args[0] + "/client/" + registration.Id commChannel, err := comms.NewCommChannel(comms.Agent, wsConn) if err != nil { panic(err) } // Authentiocation passwordHandler, authorizedKeys := setupAuthentication( commChannel, serverInfo.UserPassword, authorizedKeysFile) var service AgentService service = ListenerServer(func() *ssh.Server { return sshServer("hostkey.pem", shell, passwordHandler, authorizedKeys) }) //service = ConnectionServer(netCatServer) //service = ConnectionServer(echoServer) log.Println() log.Printf("Clients should use the following commands to connect to this agent:") log.Println() sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", clientUrl, serverInfo.UserPassword.Username) sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", clientUrl, serverInfo.UserPassword.Username) log.Println(" # For SSH") log.Println(" " + sshCommand) log.Println() log.Println(" # for SFTP") log.Println(" " + sftpCommand) log.Println() urlObject, _ := url.Parse(wsURL) extension := "" if runtime.GOOS == "windows" { extension = ".exe" } log.Printf("wsproxy can be downloaded from %s", strings.ReplaceAll(urlObject.Scheme, "ws", "http")+ "://"+urlObject.Host+"/static/wsproxy"+extension) log.Println() session.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval) listener := comms.NewAgentListener(commChannel.Session) service.Run(listener) } func setupAuthentication(commChannel comms.CommChannel, userPassword comms.UserPassword, authorizedKeysFile string) (func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { passwordHandler := func(ctx ssh.Context, password string) bool { // Replace with your own logic to validate username and password return ctx.User() == userPassword.Username && password == userPassword.Password } go comms.ListenForServerEvents(commChannel) authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) if len(authorizedKeys.keys) > 0 { log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) } return passwordHandler, authorizedKeys } func chooseShell(shells []string) string { log.Printf("Shell search path is %v", shells) var err error shell := "" for _, candidate := range shells { shell, err = exec.LookPath(candidate) if err == nil { break } } if shell == "" { log.Printf("Cannot find a shell in %v", shells) os.Exit(1) } log.Printf("Using shell %s for remote sessions", shell) return shell }