Some cleanup in the agent code.
Now supporting authorized SSH keys in the .authorized_keys file.
This commit is contained in:
		
							parent
							
								
									2ed81c3174
								
							
						
					
					
						commit
						a59011b00c
					
				| @ -57,14 +57,9 @@ func SftpHandler(sess ssh.Session) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var sshUserCredentials = comms.UserPassword{} | ||||
| 
 | ||||
| func passwordAuth(ctx ssh.Context, password string) bool { | ||||
| 	// Replace with your own logic to validate username and password
 | ||||
| 	return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password | ||||
| } | ||||
| 
 | ||||
| func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { | ||||
| func sshServer(hostKeyFile string, shellCommand string, | ||||
| 	passwordHandler ssh.PasswordHandler, | ||||
| 	authorizedPublicKeys AuthorizedPublicKeys) *ssh.Server { | ||||
| 	ssh.Handle(func(s ssh.Session) { | ||||
| 		workingDirectory, _ := os.Getwd() | ||||
| 		env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) | ||||
| @ -81,18 +76,14 @@ func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { | ||||
| 	}) | ||||
| 
 | ||||
| 	log.Println("starting ssh server, waiting for debug sessions") | ||||
| 
 | ||||
| 	server := ssh.Server{ | ||||
| 		PasswordHandler: passwordAuth, | ||||
| 		PasswordHandler:  passwordHandler, | ||||
| 		PublicKeyHandler: authorizedPublicKeys.authorize, | ||||
| 		SubsystemHandlers: map[string]ssh.SubsystemHandler{ | ||||
| 			"sftp": SftpHandler, | ||||
| 		}, | ||||
| 	} | ||||
| 	//err := generateHostKey(hostKeyFile, 2048)
 | ||||
| 	//if err != nil {
 | ||||
| 	//	log.Printf("Could not create host key file '%s': %v", hostKeyFile, err)
 | ||||
| 	//}
 | ||||
| 	//option := ssh.HostKeyFile(hostKeyFile)
 | ||||
| 
 | ||||
| 	option := ssh.HostKeyPEM(hostPrivateKey) | ||||
| 	option(&server) | ||||
| 
 | ||||
| @ -152,7 +143,7 @@ func getId(id string) string { | ||||
| 		// not specified
 | ||||
| 		return strconv.Itoa(time.Now().Nanosecond() % 1000000000) | ||||
| 	} | ||||
| 	validateString(id, "id", `^[a-zA-Z0-9-]+$`) | ||||
| 	validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`) | ||||
| 	return id | ||||
| } | ||||
| 
 | ||||
| @ -166,13 +157,15 @@ func printHelp(msg string) { | ||||
| 		"Here <ID> is the unique id of the agent that allows rendez-vous with an end-user.\n" + | ||||
| 		"The end-user must specify the same id when connecting using ssh.\n" + | ||||
| 		"\n" + | ||||
| 		"--id: Rendez-vous id. When specified an SSH authorized key must be used and password\n" + | ||||
| 		"      based access is disabled. When not specified a random id is chosen by the agent and\n" + | ||||
| 		"      password based access is possible. The password is configured on the converge server\n" + | ||||
| 		"--warning-time: advance warning time before sessio ends\n" + | ||||
| 		"--expiry-time: expiry time of the session\n" + | ||||
| 		"--id:             rendez-vous id. When specified an SSH authorized key must be used and password\n" + | ||||
| 		"                  based access is disabled. When not specified a random id is chosen by the agent and\n" + | ||||
| 		"                  password based access is possible. The password is configured on the converge server\n" + | ||||
| 		"--ssh-keys-file:  SSH authorized keys file in openssh format. By default .authorized_keys in the\n" + | ||||
| 		"                  directory where the agent is started is used.\n" + | ||||
| 		"--warning-time:   advance warning time before sessio ends\n" + | ||||
| 		"--expiry-time:    expiry time of the session\n" + | ||||
| 		"--check-interval: interval at which expiry is checked\n" + | ||||
| 		"-insecure: allow invalid certificates\n" | ||||
| 		"-insecure:        allow invalid certificates\n" | ||||
| 
 | ||||
| 	fmt.Fprintln(os.Stderr, helpText) | ||||
| 	os.Exit(1) | ||||
| @ -195,14 +188,8 @@ func parseDuration(args []string, val string) (time.Duration, []string) { | ||||
| 
 | ||||
| func main() { | ||||
| 
 | ||||
| 	// Random user name and password so that effectively no one can login
 | ||||
| 	// until the user and password have been received from the server.
 | ||||
| 	sshUserCredentials = comms.UserPassword{ | ||||
| 		Username: strconv.Itoa(rand.Int()), | ||||
| 		Password: strconv.Itoa(rand.Int()), | ||||
| 	} | ||||
| 
 | ||||
| 	id := "" | ||||
| 	authorizedKeysFile := ".authorized_keys" | ||||
| 	advanceWarningTime := 5 * time.Minute | ||||
| 	agentExpriryTime := 10 * time.Minute | ||||
| 	tickerInterval := 60 * time.Second | ||||
| @ -214,6 +201,8 @@ func main() { | ||||
| 		switch args[0] { | ||||
| 		case "--id": | ||||
| 			id, args = getArg(args) | ||||
| 		case "--ssh-keys-file": | ||||
| 			authorizedKeysFile, args = getArg(args) | ||||
| 		case "--warning-time": | ||||
| 			advanceWarningTime, args = parseDuration(args, val) | ||||
| 		case "--expiry-time": | ||||
| @ -257,31 +246,18 @@ func main() { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 
 | ||||
| 	go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { | ||||
| 		log.Println("Username and password configuration received from server") | ||||
| 		sshUserCredentials = user | ||||
| 	}) | ||||
| 	// Authentiocation
 | ||||
| 
 | ||||
| 	sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile) | ||||
| 
 | ||||
| 	// Choose shell
 | ||||
| 
 | ||||
| 	shell := chooseShell() | ||||
| 
 | ||||
| 	var service AgentService | ||||
| 	shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} | ||||
| 	if runtime.GOOS == "windows" { | ||||
| 		shells = []string{"powershell", "bash"} | ||||
| 	} | ||||
| 
 | ||||
| 	shell := "" | ||||
| 	for _, candidate := range shells { | ||||
| 		shell, err = exec.LookPath(candidate) | ||||
| 		if err == nil { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if shell == "" { | ||||
| 		log.Printf("Cannot find a shell in %v", shells) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 	log.Printf("Using shell %s for remote sessions", shell) | ||||
| 	service = ListenerServer(func() *ssh.Server { | ||||
| 		return sshServer("hostkey.pem", shell) | ||||
| 		return sshServer("hostkey.pem", shell, passwordHandler, authorizedKeys) | ||||
| 	}) | ||||
| 	//service = ConnectionServer(netCatServer)
 | ||||
| 	//service = ConnectionServer(echoServer)
 | ||||
| @ -309,3 +285,48 @@ func main() { | ||||
| 	agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval) | ||||
| 	service.Run(commChannel.Session) | ||||
| } | ||||
| 
 | ||||
| func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { | ||||
| 	// Random user name and password so that effectively no one can login
 | ||||
| 	// until the user and password have been received from the server.
 | ||||
| 	sshUserCredentials := comms.UserPassword{ | ||||
| 		Username: strconv.Itoa(rand.Int()), | ||||
| 		Password: strconv.Itoa(rand.Int()), | ||||
| 	} | ||||
| 	passwordHandler := func(ctx ssh.Context, password string) bool { | ||||
| 		// Replace with your own logic to validate username and password
 | ||||
| 		return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password | ||||
| 	} | ||||
| 	go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { | ||||
| 		log.Println("Username and password configuration received from server") | ||||
| 		sshUserCredentials = user | ||||
| 	}) | ||||
| 	authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) | ||||
| 	if len(authorizedKeys.keys) > 0 { | ||||
| 		log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) | ||||
| 	} | ||||
| 	return sshUserCredentials, passwordHandler, authorizedKeys | ||||
| } | ||||
| 
 | ||||
| func chooseShell() string { | ||||
| 	var err error | ||||
| 
 | ||||
| 	shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} | ||||
| 	if runtime.GOOS == "windows" { | ||||
| 		shells = []string{"powershell", "bash"} | ||||
| 	} | ||||
| 
 | ||||
| 	shell := "" | ||||
| 	for _, candidate := range shells { | ||||
| 		shell, err = exec.LookPath(candidate) | ||||
| 		if err == nil { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if shell == "" { | ||||
| 		log.Printf("Cannot find a shell in %v", shells) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 	log.Printf("Using shell %s for remote sessions", shell) | ||||
| 	return shell | ||||
| } | ||||
|  | ||||
							
								
								
									
										81
									
								
								cmd/agent/sshauthorizedkeys.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								cmd/agent/sshauthorizedkeys.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,81 @@ | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"fmt" | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	gossh "golang.org/x/crypto/ssh" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func publicKeyHandler(ctx ssh.Context, key gossh.PublicKey, authorizedKey gossh.PublicKey) bool { | ||||
| 	providedKey := gossh.MarshalAuthorizedKey(key) | ||||
| 
 | ||||
| 	if ssh.KeysEqual(key, authorizedKey) { | ||||
| 		log.Printf("Successful login from %s", ctx.RemoteAddr()) | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 	log.Printf("Failed login attempt from %s with key: %s", ctx.RemoteAddr(), strings.TrimSpace(string(providedKey))) | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func readSshPublicKeys(fileName string) ([]ssh.PublicKey, error) { | ||||
| 	file, err := os.Open(fileName) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("Failed to open file: '%s': %s", fileName, err) | ||||
| 	} | ||||
| 	defer file.Close() | ||||
| 
 | ||||
| 	res := make([]ssh.PublicKey, 10) | ||||
| 	scanner := bufio.NewScanner(file) | ||||
| 	for scanner.Scan() { | ||||
| 		lineText := scanner.Text() | ||||
| 		ind := strings.Index(lineText, "#") | ||||
| 		if ind >= 0 { | ||||
| 			lineText = lineText[:ind] | ||||
| 		} | ||||
| 		lineText = strings.Trim(lineText, "") | ||||
| 		if lineText == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		line := []byte(lineText) | ||||
| 		parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(line) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to parse authorized key: %v", lineText) | ||||
| 		} else { | ||||
| 			res = append(res, parsedKey) | ||||
| 		} | ||||
| 	} | ||||
| 	return res, nil | ||||
| } | ||||
| 
 | ||||
| type AuthorizedPublicKeys struct { | ||||
| 	keys []ssh.PublicKey | ||||
| } | ||||
| 
 | ||||
| func ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile string) AuthorizedPublicKeys { | ||||
| 	if authorizedKeysFile == "" { | ||||
| 		return AuthorizedPublicKeys{} | ||||
| 	} | ||||
| 	keys, err := readSshPublicKeys(authorizedKeysFile) | ||||
| 	if os.IsNotExist(err) { | ||||
| 		log.Printf("Authorized keys file '%s' not found.", authorizedKeysFile) | ||||
| 		return AuthorizedPublicKeys{} | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		log.Println("Public key authentication will not work since no public keys were found.") | ||||
| 	} | ||||
| 	return AuthorizedPublicKeys{keys: keys} | ||||
| } | ||||
| 
 | ||||
| func (key AuthorizedPublicKeys) authorize(ctx ssh.Context, userProvidedKey ssh.PublicKey) bool { | ||||
| 	for _, key := range key.keys { | ||||
| 		if publicKeyHandler(ctx, userProvidedKey, key) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user