converge/cmd/agent/agent.go
Erik Brakkee 3c803d6125 removed password based access
authorized keys can now be modified within the session.
keep last set of keys when no valid keys were found and keys are changed during the session.
2024-08-06 22:03:36 +02:00

397 lines
11 KiB
Go
Executable File

package main
import (
"bufio"
"converge/pkg/agent/session"
"converge/pkg/agent/terminal"
"converge/pkg/comms"
"converge/pkg/support/iowrappers"
"converge/pkg/support/websocketutil"
"crypto/tls"
"fmt"
"github.com/gliderlabs/ssh"
"github.com/gorilla/websocket"
"github.com/pkg/sftp"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"regexp"
"runtime"
"strconv"
"strings"
"time"
_ "embed"
_ "net/http/pprof"
)
// hostPrivateKey holds the contents of hostkey.pem, embedded into the binary
// at build time by the go:embed directive below. sshServer passes it to
// ssh.HostKeyPEM to configure the SSH server's host key.
//go:embed hostkey.pem
var hostPrivateKey []byte
// SftpHandler serves the "sftp" subsystem for a single SSH session.
//
// The session is registered through session.Login and unregistered again when
// the handler returns. The SFTP server runs directly on top of the SSH session
// stream until Serve returns.
func SftpHandler(sess ssh.Session) {
	info := comms.NewSessionInfo(sess.LocalAddr().String(), "sftp")
	session.Login(info, sess)
	defer session.LogOut(info.ClientId)

	// SFTP debug output is discarded.
	options := []sftp.ServerOption{sftp.WithDebug(io.Discard)}
	server, err := sftp.NewServer(sess, options...)
	if err != nil {
		log.Printf("sftp tcpserver init error: %s\n", err)
		return
	}

	switch err := server.Serve(); {
	case err == io.EOF:
		// End of stream: the client ended the session.
		server.Close()
		fmt.Println("sftp client exited session.")
	case err != nil:
		fmt.Println("sftp tcpserver completed with error:", err)
	}
}
// UserActivityDetector wraps a session stream and reports user activity to the
// session package whenever data is successfully read from the stream.
type UserActivityDetector struct {
	session io.ReadWriter
}

// Read reads from the wrapped stream. A read that delivers at least one byte
// without an error is reported through session.UserActivityDetected.
func (d UserActivityDetector) Read(p []byte) (int, error) {
	n, err := d.session.Read(p)
	if n <= 0 || err != nil {
		return n, err
	}
	session.UserActivityDetected()
	return n, err
}

// Write passes p to the wrapped stream unchanged; writes do not count as user
// activity.
func (d UserActivityDetector) Write(p []byte) (int, error) {
	n, err := d.session.Write(p)
	return n, err
}
// sshServer builds the SSH server that provides remote shell and SFTP access.
//
// Parameters:
//   - hostKeyFile: currently unused; the host key comes from the embedded
//     hostPrivateKey.
//   - shellCommand: the shell that is started for every interactive session.
//   - authorizedPublicKeys: its authorize method decides which public keys may
//     log in.
//
// The interactive session handler is registered through ssh.Handle; the
// returned server itself only carries the public key handler and the "sftp"
// subsystem handler.
func sshServer(hostKeyFile string, shellCommand string,
	authorizedPublicKeys *AuthorizedPublicKeys) *ssh.Server {
	ssh.Handle(func(sshSession ssh.Session) {
		workingDirectory, err := os.Getwd()
		if err != nil {
			// Not fatal: the session still works, only agentdir stays empty.
			log.Printf("cannot determine working directory: %v", err)
		}
		env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
		process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand)
		if err != nil {
			// A shell that cannot be started must only end this session. A
			// panic here would take down the whole agent, including all other
			// sessions.
			log.Printf("cannot start shell %q: %v", shellCommand, err)
			fmt.Fprintf(sshSession.Stderr(), "cannot start shell: %v\n", err)
			_ = sshSession.Exit(1)
			return
		}
		sessionInfo := comms.NewSessionInfo(
			sshSession.LocalAddr().String(), "ssh",
		)
		session.Login(sessionInfo, sshSession)
		activityDetector := UserActivityDetector{
			session: sshSession,
		}
		iowrappers.SynchronizeStreams("shell -- ssh", process.Pipe(), activityDetector)
		session.LogOut(sessionInfo.ClientId)
		// Note from the original author: this will cause additional goroutines
		// to remain alive when the SSH session is killed. For now that is
		// acceptable since the agent is a short-lived process. Using Kill()
		// here would create defunct processes, and in normal circumstances
		// Wait() is the best choice because the process shuts down
		// automatically once it has lost its terminal.
		process.Wait()
	})
	log.Println("starting ssh server, waiting for debug sessions")
	server := ssh.Server{
		PublicKeyHandler: authorizedPublicKeys.authorize,
		SubsystemHandlers: map[string]ssh.SubsystemHandler{
			"sftp": SftpHandler,
		},
	}
	option := ssh.HostKeyPEM(hostPrivateKey)
	if err := option(&server); err != nil {
		// Report the problem instead of dropping it silently; the server is
		// still returned so that existing behavior is otherwise unchanged.
		log.Printf("WARNING: cannot load embedded host key: %v", err)
	}
	return &server
}
// echoServer copies everything that is read from conn back to conn until the
// stream ends. A copy error is logged instead of being silently dropped.
func echoServer(conn io.ReadWriter) {
	log.Println("Echo service started")
	if _, err := io.Copy(conn, conn); err != nil {
		log.Printf("Echo service stopped with error: %v", err)
	}
}
// netCatServer connects conn to the standard input and standard output of the
// agent process through iowrappers.SynchronizeStreams.
func netCatServer(conn io.ReadWriter) {
	// Output is written to os.Stdout directly. bufio.NewWriterSize treats a
	// size of 0 as "use the default buffer size", so a bufio.Writer here would
	// hold data back until its buffer fills up; nothing in this file flushes it.
	stdio := struct {
		io.Reader
		io.Writer
	}{
		Reader: bufio.NewReaderSize(os.Stdin, 0),
		Writer: os.Stdout,
	}
	iowrappers.SynchronizeStreams("stdio -- ws", conn, stdio)
}
// AgentService is a service that the agent runs on top of a net.Listener.
// ListenerServer and ConnectionServer in this file implement it.
type AgentService interface {
// Run serves the connections delivered by listener.
Run(listener net.Listener)
}
// ListenerServer adapts a factory for an *ssh.Server to the AgentService
// interface.
type ListenerServer func() *ssh.Server

// Run creates the SSH server and lets it serve the connections delivered by
// listener. The error with which serving stops is logged rather than
// discarded.
func (server ListenerServer) Run(listener net.Listener) {
	if err := server().Serve(listener); err != nil {
		log.Printf("ssh server stopped: %v", err)
	}
}
type ConnectionServer func(conn io.ReadWriter)
func (server ConnectionServer) Run(listener net.Listener) {
for {
conn, err := listener.Accept()
if err != nil {
panic(err)
}
go server(conn)
}
}
// ReaderFunc lets a plain function be used where a type with a Read method is
// required.
type ReaderFunc func(p []byte) (n int, err error)

// Read calls the underlying function with p and returns its results.
func (f ReaderFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
// validateString checks value against the regular expression pattern. When the
// pattern is invalid or value does not match, printHelp is called with a
// message that names the offending setting through description.
func validateString(value, description, pattern string) {
	re, err := regexp.Compile(pattern)
	if err == nil && re.MatchString(value) {
		return
	}
	printHelp(fmt.Sprintf("%s: wrong value '%s', must conform to pattern '%s'",
		description, value, pattern))
}
// getId returns the rendez-vous id to use. A given id is validated with
// validateString and returned unchanged; an empty id is replaced by a number
// derived from the nanosecond part of the current time.
func getId(id string) string {
	if id != "" {
		validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`)
		return id
	}
	// not specified
	nanos := time.Now().Nanosecond()
	return strconv.Itoa(nanos % 1000000000)
}
// printHelp writes the usage text to standard error and terminates the process
// with exit status 1. A non-empty msg is printed first as an error line.
func printHelp(msg string) {
	if msg != "" {
		fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg)
	}
	// One entry per output line; the entries are joined with newlines.
	lines := []string{
		"agent [options] <wsUrl> ",
		"",
		"Run agent with <wsUrl> of the form ws[s]://<host>[:port]",
		"Here <ID> is the unique id of the agent that allows rendez-vous with an end-user.",
		"The end-user must specify the same id when connecting using ssh.",
		"",
		"--id: rendez-vous id, this is the id used to connect agents and clients. ",
		"--authorized-keys: SSH authorized keys file in openssh format. By default .authorized_keys in the",
		" directory where the agent is started is used.",
		"--warning-time: advance warning time before sessio ends (default '5m')",
		"--expiry-time: expiry time of the session (default '10m')",
		"--check-interval: interval at which expiry is checked",
		"--insecure: allow invalid certificates",
		"--shells: comma-separated list of shells to add to the front of theshell search path",
		" (e.g. 'zsh,sh'). If the shell name contains a slash,then the path must exist:",
		" either relative to the agent's current directory or absolute. Otherwise it is looekd",
		" up in the system search path. ",
	}
	fmt.Fprintln(os.Stderr, strings.Join(lines, "\n"))
	os.Exit(1)
}
// getArg extracts the argument of the option at args[0].
//
// It returns the argument (args[1]) together with args[1:], i.e. the returned
// slice still starts with the argument itself. When the argument is missing,
// printHelp is called with an error message.
func getArg(args []string) (value string, ret []string) {
	if len(args) < 2 {
		printHelp(fmt.Sprintf("The '%s' option expects an argument", args[0]))
	}
	ret = args[1:]
	value = ret[0]
	return value, ret
}
// parseDuration reads the argument of the option at args[0] through getArg and
// parses it with time.ParseDuration. It returns the duration and the remaining
// arguments as returned by getArg; a value that cannot be parsed is reported
// through printHelp.
func parseDuration(args []string) (time.Duration, []string) {
	text, rest := getArg(args)
	parsed, err := time.ParseDuration(text)
	if err != nil {
		printHelp(fmt.Sprintf("Error parsing duration: %v\n", err))
	}
	return parsed, rest
}
// main is the entry point of the agent.
//
// It parses the command line, connects to the server given by <wsUrl> over a
// websocket, registers the agent under its rendez-vous id, prints the commands
// clients can use to connect, and finally serves SSH/SFTP sessions over the
// established communication channel.
//
// Every failure during startup ends the process with a non-zero exit status.
func main() {
	// Optional profiling endpoint, enabled through the PPROF_PORT variable.
	if pprofPort, ok := os.LookupEnv("PPROF_PORT"); ok {
		log.Printf("Enabllng pprof on localhost:%s", pprofPort)
		go func() {
			log.Println(http.ListenAndServe("localhost:"+pprofPort, nil))
		}()
	}

	// Defaults, overridable through command line options.
	id := ""
	authorizedKeysFile := ".authorized_keys"
	advanceWarningTime := 5 * time.Minute
	agentExpiryTime := 10 * time.Minute
	tickerInterval := 60 * time.Second
	insecure := false

	shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"}
	if runtime.GOOS == "windows" {
		shells = []string{"powershell", "bash"}
	}

	args := os.Args[1:]
	additionalShells := []string{}
	commaSeparated := ""

	// getArg and parseDuration return the argument list starting at the option
	// value, so the shift at the end of the loop body skips either the value
	// or, for options without a value, the option itself.
	for len(args) > 0 && strings.HasPrefix(args[0], "-") {
		switch args[0] {
		case "--id":
			id, args = getArg(args)
		case "--authorized-keys":
			authorizedKeysFile, args = getArg(args)
		case "--warning-time":
			advanceWarningTime, args = parseDuration(args)
		case "--expiry-time":
			agentExpiryTime, args = parseDuration(args)
		case "--check-interval":
			tickerInterval, args = parseDuration(args)
		case "--insecure":
			insecure = true
		case "--shells":
			commaSeparated, args = getArg(args)
			additionalShells = append(additionalShells, strings.Split(commaSeparated, ",")...)
		default:
			printHelp("Unknown option " + args[0])
		}
		args = args[1:]
	}

	if 2*advanceWarningTime > agentExpiryTime {
		printHelp("The warning time should be at most half the expiry time")
	}
	if 4*tickerInterval > agentExpiryTime {
		printHelp("The check interval should be at most 1/4 of the expiry time")
	}

	// Shells given on the command line take precedence over the defaults.
	shells = append(additionalShells, shells...)
	id = getId(id)

	if len(args) != 1 {
		printHelp("")
	}
	wsURL := args[0]
	// The parsed URL gets its own name so that the net/url package is not
	// shadowed; it is reused below to build the wsproxy download link.
	baseURL, err := url.Parse(wsURL)
	if err != nil {
		printHelp(fmt.Sprintf("Invalid URL %s", wsURL))
	}
	wsURL += "/agent/" + id

	dialer := websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
	}
	if insecure {
		dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}
	conn, _, err := dialer.Dial(wsURL, nil)
	if err != nil {
		// Exit with a failure status, like all other startup errors.
		log.Println("WebSocket connection error:", err)
		os.Exit(1)
	}
	wsConn := websocketutil.NewWebSocketConn(conn, false)
	defer wsConn.Close()

	shell := chooseShell(shells)
	serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo(shell))
	if err != nil {
		log.Printf("ERROR: %v", err)
		os.Exit(1)
	}

	registration, err := comms.ReceiveRegistrationMessage(wsConn)
	if err != nil {
		log.Printf("ERROR: %v", err)
		os.Exit(1)
	}
	log.Println("Server responded with: ", registration.Message)
	if registration.Id != id {
		log.Println("==============================================================================")
		log.Println("Duplicate agent id detected: the server allocated a new id to be used instead.")
		log.Println("")
		log.Println(registration.Id)
		log.Println("==============================================================================")
	}
	clientURL := args[0] + "/client/" + registration.Id

	commChannel, err := comms.NewCommChannel(comms.Agent, wsConn)
	if err != nil {
		log.Printf("ERROR: %v", err)
		os.Exit(1)
	}

	// Authentication
	authorizedKeys := NewAuthorizedPublicKeys(authorizedKeysFile)
	// Initial check: refuse to start without at least one usable public key.
	pubkeys := authorizedKeys.Parse()
	if len(pubkeys) == 0 {
		log.Printf("No public keys found in '%s', exiting", authorizedKeysFile)
		os.Exit(1)
	}

	go comms.ListenForServerEvents(commChannel)

	var service AgentService = ListenerServer(func() *ssh.Server {
		return sshServer("hostkey.pem", shell, authorizedKeys)
	})
	//service = ConnectionServer(netCatServer)
	//service = ConnectionServer(echoServer)

	log.Println()
	log.Printf("Clients should use the following commands to connect to this agent:")
	log.Println()
	sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
		clientURL, serverInfo.UserPassword.Username)
	sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
		clientURL, serverInfo.UserPassword.Username)
	log.Println(" # For SSH")
	log.Println(" " + sshCommand)
	log.Println()
	log.Println(" # for SFTP")
	log.Println(" " + sftpCommand)
	log.Println()

	extension := ""
	if runtime.GOOS == "windows" {
		extension = ".exe"
	}
	// The download link uses the http(s) counterpart of the ws(s) scheme and
	// the host of the server URL given on the command line.
	log.Printf("wsproxy can be downloaded from %s",
		strings.ReplaceAll(baseURL.Scheme, "ws", "http")+
			"://"+baseURL.Host+"/static/wsproxy"+extension)
	log.Println()

	session.ConfigureAgent(commChannel, advanceWarningTime, agentExpiryTime, tickerInterval)
	listener := comms.NewAgentListener(commChannel.Session)
	service.Run(listener)
}
// chooseShell returns the path of the first entry of shells that
// exec.LookPath resolves without an error. When no entry can be resolved, the
// problem is logged and the process exits with status 1.
//
// Only a lookup that succeeded is ever used as the result: a path that
// exec.LookPath returns together with an error (for example exec.ErrDot) is
// skipped like any other failed lookup.
func chooseShell(shells []string) string {
	log.Printf("Shell search path is %v", shells)
	for _, candidate := range shells {
		shell, err := exec.LookPath(candidate)
		if err != nil {
			continue
		}
		log.Printf("Using shell %s for remote sessions", shell)
		return shell
	}
	log.Printf("Cannot find a shell in %v", shells)
	os.Exit(1)
	return "" // not reached; os.Exit does not return
}