converge/cmd/agent/agent.go
Erik Brakkee 621bbd8ca6 GOB channel for easily and asynchronously using GOB on a single network connection, also dealing with timeouts and errors in a good way.
Protocol version is now checked when the agent connects to the converge server.

Next up: sending connection metadata and username/password from server to agent and sending environment information back to the server. This means that the side channel will then only be used for expiry time messages and session type, with the client id passed in so the converge server can then correlate the results back to the correct channel.
2024-07-27 11:21:35 +02:00

349 lines
9.2 KiB
Go
Executable File

package main
import (
"bufio"
"converge/pkg/agent"
"converge/pkg/comms"
"converge/pkg/iowrappers"
"converge/pkg/terminal"
"converge/pkg/websocketutil"
"crypto/tls"
"fmt"
"github.com/gliderlabs/ssh"
"github.com/gorilla/websocket"
"github.com/pkg/sftp"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"regexp"
"runtime"
"strconv"
"strings"
"time"
_ "embed"
)
// hostPrivateKey holds the bytes of hostkey.pem, embedded into the binary at
// build time. sshServer passes it to ssh.HostKeyPEM to configure the server's
// host key.
//
//go:embed hostkey.pem
var hostPrivateKey []byte
// SftpHandler handles the "sftp" subsystem of an SSH session.
//
// It calls agent.Login with an id derived from the current time in
// milliseconds, runs an SFTP server on top of the session and calls
// agent.LogOut with the same id when it returns.
func SftpHandler(sess ssh.Session) {
	sessionID := int(time.Now().UnixMilli())
	agent.Login(sessionID, sess)
	defer agent.LogOut(sessionID)

	// SFTP debug output is discarded.
	server, err := sftp.NewServer(sess, sftp.WithDebug(io.Discard))
	if err != nil {
		log.Printf("sftp tcpserver init error: %s\n", err)
		return
	}

	// Serve blocks until the SFTP conversation is over.
	switch serveErr := server.Serve(); {
	case serveErr == io.EOF:
		server.Close()
		fmt.Println("sftp client exited session.")
	case serveErr != nil:
		fmt.Println("sftp tcpserver completed with error:", serveErr)
	}
}
// sshServer registers the interactive session handler and returns an SSH
// server configured with the given authentication callbacks and the "sftp"
// subsystem.
//
// Parameters:
//   - hostKeyFile: currently unused; the host key always comes from the
//     embedded hostPrivateKey.
//   - shellCommand: command started (through terminal.PtySpawner) for every
//     interactive session.
//   - passwordHandler: callback used for password authentication.
//   - authorizedPublicKeys: its authorize method is used for public key
//     authentication.
//
// The session handler is installed with ssh.Handle, i.e. as the package-level
// default handler of the ssh library.
func sshServer(hostKeyFile string, shellCommand string,
	passwordHandler ssh.PasswordHandler,
	authorizedPublicKeys AuthorizedPublicKeys) *ssh.Server {
	ssh.Handle(func(s ssh.Session) {
		workingDirectory, err := os.Getwd()
		if err != nil {
			// Not fatal: the session still works, agentdir is just empty.
			log.Printf("cannot determine working directory: %v", err)
		}
		env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
		process, err := terminal.PtySpawner.Start(s, env, shellCommand)
		if err != nil {
			// A failure to start one shell must not take down the whole agent,
			// so report it and end only this session.
			log.Printf("cannot start shell %q: %v", shellCommand, err)
			// Best effort: the session is being abandoned anyway.
			_ = s.Exit(1)
			return
		}
		uid := int(time.Now().UnixMilli())
		agent.Login(uid, s)
		iowrappers.SynchronizeStreams(process.Pipe(), s)
		agent.LogOut(uid)
		// Wait once for the spawned process to finish.
		process.Wait()
	})

	log.Println("starting ssh server, waiting for debug sessions")
	server := ssh.Server{
		PasswordHandler:  passwordHandler,
		PublicKeyHandler: authorizedPublicKeys.authorize,
		SubsystemHandlers: map[string]ssh.SubsystemHandler{
			"sftp": SftpHandler,
		},
	}
	// Install the embedded host key; report a bad key instead of ignoring it.
	if err := ssh.HostKeyPEM(hostPrivateKey)(&server); err != nil {
		log.Printf("cannot install embedded host key: %v", err)
	}
	return &server
}
// echoServer copies everything read from conn back to conn until the read
// side reports EOF or an error occurs. A copy error is logged rather than
// silently dropped.
func echoServer(conn io.ReadWriter) {
	log.Println("Echo service started")
	if _, err := io.Copy(conn, conn); err != nil {
		log.Printf("Echo service stopped with error: %v", err)
	}
}
// netCatServer connects conn to the agent's own standard input and output by
// handing both to iowrappers.SynchronizeStreams.
//
// NOTE(review): a size of 0 does not disable buffering; bufio substitutes its
// own minimum/default sizes. Whether output can linger in the stdout writer
// depends on how SynchronizeStreams copies and flushes, so confirm.
func netCatServer(conn io.ReadWriter) {
	stdinReader := bufio.NewReaderSize(os.Stdin, 0)
	stdoutWriter := bufio.NewWriterSize(os.Stdout, 0)
	iowrappers.SynchronizeStreams(conn, bufio.NewReadWriter(stdinReader, stdoutWriter))
}
// AgentService is a service the agent can run on top of a listener.
// ListenerServer and ConnectionServer in this file both provide Run.
type AgentService interface {
// Run serves connections obtained from listener.
Run(listener net.Listener)
}
// ListenerServer is an AgentService backed by a factory that produces the
// SSH server to run.
type ListenerServer func() *ssh.Server

// Run builds the SSH server and serves listener with it. Serve blocks; when
// it returns with an error, the error is logged so the agent does not stop
// silently.
func (server ListenerServer) Run(listener net.Listener) {
	if err := server().Serve(listener); err != nil {
		log.Printf("ssh server stopped: %v", err)
	}
}
// ConnectionServer is an AgentService that handles every accepted connection
// with a plain function working on an io.ReadWriter.
type ConnectionServer func(conn io.ReadWriter)

// Run accepts connections from listener forever and handles each one in its
// own goroutine. The connection is closed once the handler returns; the
// handler only sees an io.ReadWriter and cannot close it itself.
//
// Run panics when Accept fails.
func (server ConnectionServer) Run(listener net.Listener) {
	for {
		conn, err := listener.Accept()
		if err != nil {
			panic(err)
		}
		go func() {
			// Release the connection when the handler is done with it.
			defer conn.Close()
			server(conn)
		}()
	}
}
// ReaderFunc lets an ordinary function be used wherever a value with a
// Read(p []byte) (int, error) method is required.
type ReaderFunc func(p []byte) (n int, err error)

// Read invokes the function itself with buf and returns its results as-is.
func (fn ReaderFunc) Read(buf []byte) (int, error) {
	return fn(buf)
}
// validateString checks value against the regular expression pattern.
// When the pattern is invalid or value does not match, it calls printHelp
// with a message naming description, value and pattern.
func validateString(value, description, pattern string) {
	ok, err := regexp.MatchString(pattern, value)
	if err == nil && ok {
		return
	}
	message := fmt.Sprintf("%s: wrong value '%s', must conform to pattern '%s'",
		description, value, pattern)
	printHelp(message)
}
// getId returns the rendez-vous id to use.
//
// A non-empty id is validated (alphanumeric first character followed by at
// least one alphanumeric or '-') and returned unchanged. For an empty id, a
// numeric id is generated from the nanosecond part of the current time.
func getId(id string) string {
	if id != "" {
		validateString(id, "id", `^[a-zA-Z0-9][a-zA-Z0-9-]+$`)
		return id
	}
	// not specified
	nanos := time.Now().Nanosecond() % 1000000000
	return strconv.Itoa(nanos)
}
// printHelp writes the usage text to standard error and terminates the
// process with exit status 1. When msg is non-empty, it is printed first as
// an error line.
func printHelp(msg string) {
	if msg != "" {
		fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg)
	}
	helpText := "agent [options] <wsUrl> \n" +
		"\n" +
		"Run agent with <wsUrl> of the form ws[s]://<host>[:port]\n" +
		"Here <ID> is the unique id of the agent that allows rendez-vous with an end-user.\n" +
		"The end-user must specify the same id when connecting using ssh.\n" +
		"\n" +
		"--id: rendez-vous id. When specified an SSH authorized key must be used and password\n" +
		" based access is disabled. When not specified a random id is chosen by the agent and\n" +
		" password based access is possible. The password is configured on the converge server\n" +
		"--ssh-keys-file: SSH authorized keys file in openssh format. By default .authorized_keys in the\n" +
		" directory where the agent is started is used.\n" +
		"--warning-time: advance warning time before session ends\n" +
		"--expiry-time: expiry time of the session\n" +
		"--check-interval: interval at which expiry is checked\n" +
		// The option parser in main matches "--insecure" (two dashes).
		"--insecure: allow invalid certificates\n"
	fmt.Fprintln(os.Stderr, helpText)
	os.Exit(1)
}
// getArg returns the argument of the option at args[0].
//
// value is args[1] and ret is args with the option name dropped, so ret still
// starts with the value. When the option has no argument, printHelp is
// called with an explanatory message.
func getArg(args []string) (value string, ret []string) {
	if len(args) <= 1 {
		printHelp(fmt.Sprintf("The '%s' option expects an argument", args[0]))
	}
	rest := args[1:]
	return rest[0], rest
}
// parseDuration converts val with time.ParseDuration and returns the result
// together with args minus its first element. When val cannot be parsed,
// printHelp is called with the parse error.
func parseDuration(args []string, val string) (time.Duration, []string) {
	parsed, parseErr := time.ParseDuration(val)
	if parseErr != nil {
		printHelp(fmt.Sprintf("Error parsing duration: %v\n", parseErr))
	}
	remaining := args[1:]
	return parsed, remaining
}
// main parses the command line, connects to the converge server through a
// WebSocket at <wsUrl>/agent/<id>, sets up authentication and then runs the
// SSH service on a listener created from the communication channel.
//
// Usage: agent [options] <wsUrl>; printHelp lists the options.
func main() {
	id := ""
	authorizedKeysFile := ".authorized_keys"
	advanceWarningTime := 5 * time.Minute
	agentExpiryTime := 10 * time.Minute
	tickerInterval := 60 * time.Second
	insecure := false

	args := os.Args[1:]
	for len(args) > 0 && strings.HasPrefix(args[0], "-") {
		var val string
		switch args[0] {
		case "--id":
			id, args = getArg(args)
		case "--ssh-keys-file":
			authorizedKeysFile, args = getArg(args)
		// For the duration options the value must first be read with getArg.
		// getArg already drops the option name, so the slice returned by
		// parseDuration is ignored and the value itself is dropped by the
		// shared args[1:] at the end of the loop body.
		case "--warning-time":
			val, args = getArg(args)
			advanceWarningTime, _ = parseDuration(args, val)
		case "--expiry-time":
			val, args = getArg(args)
			agentExpiryTime, _ = parseDuration(args, val)
		case "--check-interval":
			val, args = getArg(args)
			tickerInterval, _ = parseDuration(args, val)
		// The single-dash spelling is accepted as well because the help text
		// has advertised it.
		case "--insecure", "-insecure":
			insecure = true
		default:
			printHelp("Unknown option " + args[0])
		}
		args = args[1:]
	}

	id = getId(id)

	if len(args) != 1 {
		printHelp("")
	}

	wsURL := args[0]
	// Named baseURL so that the net/url package is not shadowed.
	baseURL, err := url.Parse(wsURL)
	if err != nil {
		printHelp(fmt.Sprintf("Invalid URL %s", wsURL))
	}
	if baseURL.Path != "" && baseURL.Path != "/" {
		printHelp(fmt.Sprintf("Only a base URL without path may be specified: %s", wsURL))
	}
	// A base URL ending in "/" is allowed above; avoid producing "//agent/".
	wsURL = strings.TrimSuffix(wsURL, "/") + "/agent/" + id

	dialer := websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
	}
	if insecure {
		dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}
	conn, _, err := dialer.Dial(wsURL, nil)
	if err != nil {
		log.Println("WebSocket connection error:", err)
		return
	}
	wsConn := websocketutil.NewWebSocketConn(conn)
	defer wsConn.Close()

	err = comms.CheckProtocolVersion(comms.Agent, wsConn)
	if err != nil {
		log.Println("Protocol version check failed:", err)
		os.Exit(1)
	}

	commChannel, err := comms.NewCommChannel(comms.Agent, wsConn)
	if err != nil {
		panic(err)
	}

	// Authentication
	sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile)

	// Choose shell
	shell := chooseShell()

	var service AgentService
	service = ListenerServer(func() *ssh.Server {
		return sshServer("hostkey.pem", shell, passwordHandler, authorizedKeys)
	})
	//service = ConnectionServer(netCatServer)
	//service = ConnectionServer(echoServer)

	log.Println()
	log.Printf("Clients should use the following commands to connect to this agent:")
	log.Println()
	clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/")
	sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
		clientUrl, sshUserCredentials.Username)
	sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
		clientUrl, sshUserCredentials.Username)
	log.Println(" # For SSH")
	log.Println(" " + sshCommand)
	log.Println()
	log.Println(" # for SFTP")
	log.Println(" " + sftpCommand)
	log.Println()
	// Scheme and host of the agent URL are those of the base URL: ws -> http,
	// wss -> https.
	log.Printf("wsproxy can be downloaded from %s",
		strings.ReplaceAll(baseURL.Scheme, "ws", "http")+
			"://"+baseURL.Host+"/docs/wsproxy")
	log.Println()

	agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpiryTime, tickerInterval)

	listener := comms.NewAgentListener(commChannel.Session)
	service.Run(listener)
}
// setupAuthentication prepares password and public key authentication.
//
// It returns the initial credentials, a password handler that compares the
// login against the current credentials, and the authorized keys parsed from
// authorizedKeysFile. A goroutine running comms.ListenForServerEvents
// replaces the credentials whenever the callback is invoked.
//
// NOTE(review): the credentials are returned by value, so the caller keeps
// the initial random pair even after the callback has replaced them; confirm
// that this is intended.
// NOTE(review): the callback writes the credentials while the password
// handler may read them from another goroutine, without synchronisation.
// NOTE(review): the placeholder credentials come from math/rand; confirm this
// is unpredictable enough for the Go version in use.
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
	// Random user name and password so that effectively no one can login
	// until the user and password have been received from the server.
	credentials := comms.UserPassword{
		Username: strconv.Itoa(rand.Int()),
		Password: strconv.Itoa(rand.Int()),
	}

	checkPassword := func(ctx ssh.Context, password string) bool {
		if ctx.User() != credentials.Username {
			return false
		}
		return password == credentials.Password
	}

	onCredentials := func(user comms.UserPassword) {
		log.Println("Username and password configuration received from server")
		credentials = user
	}
	go comms.ListenForServerEvents(commChannel, onCredentials)

	authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile)
	if keyCount := len(authorizedKeys.keys); keyCount > 0 {
		log.Printf("A total of %d authorized ssh keys were found", keyCount)
	}

	return credentials, checkPassword, authorizedKeys
}
// chooseShell returns the path of the first shell from a fixed candidate
// list that exec.LookPath resolves without error. On Windows the candidates
// are powershell and bash; elsewhere a list of common Unix shells is used.
//
// When no candidate can be resolved, the problem is logged and the process
// exits with status 1.
func chooseShell() string {
	shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"}
	if runtime.GOOS == "windows" {
		shells = []string{"powershell", "bash"}
	}

	for _, candidate := range shells {
		// Only a nil error counts as found: LookPath can return a non-empty
		// path together with an error (exec.ErrDot), so the returned path
		// alone must not be used as the success indicator.
		shell, err := exec.LookPath(candidate)
		if err == nil {
			log.Printf("Using shell %s for remote sessions", shell)
			return shell
		}
	}

	log.Printf("Cannot find a shell in %v", shells)
	os.Exit(1)
	return ""
}