diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 9a9ad7d..39752c1 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -57,14 +57,9 @@ func SftpHandler(sess ssh.Session) { } } -var sshUserCredentials = comms.UserPassword{} - -func passwordAuth(ctx ssh.Context, password string) bool { - // Replace with your own logic to validate username and password - return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password -} - -func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { +func sshServer(hostKeyFile string, shellCommand string, + passwordHandler ssh.PasswordHandler, + authorizedPublicKeys AuthorizedPublicKeys) *ssh.Server { ssh.Handle(func(s ssh.Session) { workingDirectory, _ := os.Getwd() env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) @@ -81,18 +76,14 @@ func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { }) log.Println("starting ssh server, waiting for debug sessions") + server := ssh.Server{ - PasswordHandler: passwordAuth, + PasswordHandler: passwordHandler, + PublicKeyHandler: authorizedPublicKeys.authorize, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": SftpHandler, }, } - //err := generateHostKey(hostKeyFile, 2048) - //if err != nil { - // log.Printf("Could not create host key file '%s': %v", hostKeyFile, err) - //} - //option := ssh.HostKeyFile(hostKeyFile) - option := ssh.HostKeyPEM(hostPrivateKey) option(&server) @@ -152,7 +143,7 @@ func getId(id string) string { // not specified return strconv.Itoa(time.Now().Nanosecond() % 1000000000) } - validateString(id, "id", `^[a-zA-Z0-9-]+$`) + validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`) return id } @@ -166,13 +157,15 @@ func printHelp(msg string) { "Here is the unique id of the agent that allows rendez-vous with an end-user.\n" + "The end-user must specify the same id when connecting using ssh.\n" + "\n" + - "--id: Rendez-vous id. When specified an SSH authorized key must be used and password\n" + - " based access is disabled. When not specified a random id is chosen by the agent and\n" + - " password based access is possible. The password is configured on the converge server\n" + - "--warning-time: advance warning time before sessio ends\n" + - "--expiry-time: expiry time of the session\n" + + "--id: rendez-vous id. When specified an SSH authorized key must be used and password\n" + + " based access is disabled. When not specified a random id is chosen by the agent and\n" + + " password based access is possible. The password is configured on the converge server\n" + + "--ssh-keys-file: SSH authorized keys file in openssh format. By default .authorized_keys in the\n" + + " directory where the agent is started is used.\n" + + "--warning-time: advance warning time before sessio ends\n" + + "--expiry-time: expiry time of the session\n" + "--check-interval: interval at which expiry is checked\n" + - "-insecure: allow invalid certificates\n" + "-insecure: allow invalid certificates\n" fmt.Fprintln(os.Stderr, helpText) os.Exit(1) @@ -195,14 +188,8 @@ func parseDuration(args []string, val string) (time.Duration, []string) { func main() { - // Random user name and password so that effectively no one can login - // until the user and password have been received from the server. - sshUserCredentials = comms.UserPassword{ - Username: strconv.Itoa(rand.Int()), - Password: strconv.Itoa(rand.Int()), - } - id := "" + authorizedKeysFile := ".authorized_keys" advanceWarningTime := 5 * time.Minute agentExpriryTime := 10 * time.Minute tickerInterval := 60 * time.Second @@ -214,6 +201,8 @@ func main() { switch args[0] { case "--id": id, args = getArg(args) + case "--ssh-keys-file": + authorizedKeysFile, args = getArg(args) case "--warning-time": advanceWarningTime, args = parseDuration(args, val) case "--expiry-time": @@ -257,31 +246,18 @@ func main() { panic(err) } - go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { - log.Println("Username and password configuration received from server") - sshUserCredentials = user - }) + // Authentiocation + + sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile) + + // Choose shell + + shell := chooseShell() var service AgentService - shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} - if runtime.GOOS == "windows" { - shells = []string{"powershell", "bash"} - } - shell := "" - for _, candidate := range shells { - shell, err = exec.LookPath(candidate) - if err == nil { - break - } - } - if shell == "" { - log.Printf("Cannot find a shell in %v", shells) - os.Exit(1) - } - log.Printf("Using shell %s for remote sessions", shell) service = ListenerServer(func() *ssh.Server { - return sshServer("hostkey.pem", shell) + return sshServer("hostkey.pem", shell, passwordHandler, authorizedKeys) }) //service = ConnectionServer(netCatServer) //service = ConnectionServer(echoServer) @@ -309,3 +285,48 @@ func main() { agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval) service.Run(commChannel.Session) } + +func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { + // Random user name and password so that effectively no one can login + // until the user and password have been received from the server. + sshUserCredentials := comms.UserPassword{ + Username: strconv.Itoa(rand.Int()), + Password: strconv.Itoa(rand.Int()), + } + passwordHandler := func(ctx ssh.Context, password string) bool { + // Replace with your own logic to validate username and password + return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password + } + go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { + log.Println("Username and password configuration received from server") + sshUserCredentials = user + }) + authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) + if len(authorizedKeys.keys) > 0 { + log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) + } + return sshUserCredentials, passwordHandler, authorizedKeys +} + +func chooseShell() string { + var err error + + shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} + if runtime.GOOS == "windows" { + shells = []string{"powershell", "bash"} + } + + shell := "" + for _, candidate := range shells { + shell, err = exec.LookPath(candidate) + if err == nil { + break + } + } + if shell == "" { + log.Printf("Cannot find a shell in %v", shells) + os.Exit(1) + } + log.Printf("Using shell %s for remote sessions", shell) + return shell +} diff --git a/cmd/agent/sshauthorizedkeys.go b/cmd/agent/sshauthorizedkeys.go new file mode 100644 index 0000000..d2b87d8 --- /dev/null +++ b/cmd/agent/sshauthorizedkeys.go @@ -0,0 +1,81 @@ +package main + +import ( + "bufio" + "fmt" + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + "log" + "os" + "strings" +) + +func publicKeyHandler(ctx ssh.Context, key gossh.PublicKey, authorizedKey gossh.PublicKey) bool { + providedKey := gossh.MarshalAuthorizedKey(key) + + if ssh.KeysEqual(key, authorizedKey) { + log.Printf("Successful login from %s", ctx.RemoteAddr()) + return true + } + + log.Printf("Failed login attempt from %s with key: %s", ctx.RemoteAddr(), strings.TrimSpace(string(providedKey))) + return false +} + +func readSshPublicKeys(fileName string) ([]ssh.PublicKey, error) { + file, err := os.Open(fileName) + if err != nil { + return nil, fmt.Errorf("Failed to open file: '%s': %s", fileName, err) + } + defer file.Close() + + res := make([]ssh.PublicKey, 10) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + lineText := scanner.Text() + ind := strings.Index(lineText, "#") + if ind >= 0 { + lineText = lineText[:ind] + } + lineText = strings.Trim(lineText, "") + if lineText == "" { + continue + } + line := []byte(lineText) + parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(line) + if err != nil { + log.Printf("Failed to parse authorized key: %v", lineText) + } else { + res = append(res, parsedKey) + } + } + return res, nil +} + +type AuthorizedPublicKeys struct { + keys []ssh.PublicKey +} + +func ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile string) AuthorizedPublicKeys { + if authorizedKeysFile == "" { + return AuthorizedPublicKeys{} + } + keys, err := readSshPublicKeys(authorizedKeysFile) + if os.IsNotExist(err) { + log.Printf("Authorized keys file '%s' not found.", authorizedKeysFile) + return AuthorizedPublicKeys{} + } + if err != nil { + log.Println("Public key authentication will not work since no public keys were found.") + } + return AuthorizedPublicKeys{keys: keys} +} + +func (key AuthorizedPublicKeys) authorize(ctx ssh.Context, userProvidedKey ssh.PublicKey) bool { + for _, key := range key.keys { + if publicKeyHandler(ctx, userProvidedKey, key) { + return true + } + } + return false +}