diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index f10666b..cf9f25c 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -5,13 +5,11 @@ import ( "crypto/tls" "fmt" "git.wamblee.org/converge/pkg/agent/session" - "git.wamblee.org/converge/pkg/agent/terminal" "git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/support/iowrappers" "git.wamblee.org/converge/pkg/support/websocketutil" "github.com/gliderlabs/ssh" "github.com/gorilla/websocket" - "github.com/pkg/sftp" "io" "log" "net" @@ -29,79 +27,6 @@ import ( _ "net/http/pprof" ) -func SftpHandler(sftpSession ssh.Session) { - sessionInfo := comms.NewSessionInfo( - sftpSession.LocalAddr().String(), - "sftp", - ) - session.Login(sessionInfo, sftpSession) - defer session.LogOut(sessionInfo.ClientId) - - debugStream := io.Discard - serverOptions := []sftp.ServerOption{ - sftp.WithDebug(debugStream), - } - // activity for sftp means that the server is sending data to the client. - // In contrast to activity detection for ssh we use activity detection based on - // serveractivity. This approach ensures that long downloads are not interrupted - // by a timeout and are allowed to finish. - activityDetector := NewSftpActivityDetector(sftpSession) - server, err := sftp.NewServer( - activityDetector, - serverOptions..., - ) - if err != nil { - log.Printf("sftp tcpserver init error: %s\n", err) - return - } - if err := server.Serve(); err == io.EOF { - server.Close() - fmt.Println("sftp client exited session.") - } else if err != nil { - fmt.Println("sftp tcpserver completed with error:", err) - } -} - -func sshServer(hostPrivateKey []byte, shellCommand string, - authorizedPublicKeys *AuthorizedPublicKeys) *ssh.Server { - ssh.Handle(func(sshSession ssh.Session) { - workingDirectory, _ := os.Getwd() - env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) - process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand) - if err != nil { - panic(err) - } - sessionInfo := comms.NewSessionInfo( - sshSession.LocalAddr().String(), "ssh", - ) - session.Login(sessionInfo, sshSession) - // For SSH we detect activity when there are writes to the shell that was spanwedn. - // This means user input. - activityDetector := NewWriteDetector(process.Pipe()) - iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession) - session.LogOut(sessionInfo.ClientId) - // will cause addition goroutines to remmain alive when the SSH - // session is killed. For now acceptable since the agent is a short-lived - // process. Using Kill() here will create defunct processes and in normal - // circummstances Wait() will be the best because the process will be shutting - // down automatically becuase it has lost its terminal. - process.Wait() - }) - - log.Println("starting ssh server, waiting for debug sessions") - - server := ssh.Server{ - PublicKeyHandler: authorizedPublicKeys.authorize, - SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": SftpHandler, - }, - } - option := ssh.HostKeyPEM(hostPrivateKey) - option(&server) - - return &server -} - func echoServer(conn io.ReadWriter) { log.Println("Echo service started") io.Copy(conn, conn) @@ -330,9 +255,11 @@ func main() { var service AgentService - service = ListenerServer(func() *ssh.Server { - return sshServer(registration.HostPrivateKey, shell, authorizedKeys) - }) + service = SshAgentService{ + hostPrivateKey: registration.HostPrivateKey, + shellCommand: shell, + authorizedKeys: authorizedKeys, + } //service = ConnectionServer(netCatServer) //service = ConnectionServer(echoServer) log.Println() diff --git a/cmd/agent/sshservice.go b/cmd/agent/sshservice.go new file mode 100644 index 0000000..1849068 --- /dev/null +++ b/cmd/agent/sshservice.go @@ -0,0 +1,119 @@ +package main + +import ( + "fmt" + "git.wamblee.org/converge/pkg/agent/session" + "git.wamblee.org/converge/pkg/agent/terminal" + "git.wamblee.org/converge/pkg/comms" + "git.wamblee.org/converge/pkg/support/iowrappers" + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + "io" + "log" + "net" + "os" + "strings" +) + +type SshAgentService struct { + hostPrivateKey []byte + shellCommand string + authorizedKeys *AuthorizedPublicKeys +} + +func (s SshAgentService) Run(listener net.Listener) { + s.sshServer().Serve(listener) +} + +type SshAgentSession struct { + session ssh.Session +} + +func NewSshAgentSession(session ssh.Session) SshAgentSession { + return SshAgentSession{session: session} +} + +func (s SshAgentSession) MessageUser(message string) { + for _, line := range strings.Split(message, "\n") { + io.WriteString(s.session.Stderr(), "### "+line+"\n\r") + } + io.WriteString(s.session.Stderr(), "\n\r") +} + +func (s SshAgentSession) Type() string { + sessionType := s.session.Subsystem() + if sessionType == "" { + sessionType = "ssh" + } + return sessionType +} + +func (s *SshAgentService) sshServer() *ssh.Server { + ssh.Handle(func(sshSession ssh.Session) { + workingDirectory, _ := os.Getwd() + env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) + process, err := terminal.PtySpawner.Start(sshSession, env, s.shellCommand) + if err != nil { + panic(err) + } + sessionInfo := comms.NewSessionInfo( + sshSession.LocalAddr().String(), "ssh", + ) + session.Login(sessionInfo, NewSshAgentSession(sshSession)) + // For SSH we detect activity when there are writes to the shell that was spawned. + // This means user input. + activityDetector := NewWriteDetector(process.Pipe()) + iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession) + session.LogOut(sessionInfo.ClientId) + // Using Kill() here will create defunct processes and in normal + // circumstances Wait() will be the best because the process will be shutting + // down automatically becuase it has lost its terminal. + process.Wait() + }) + + log.Println("starting ssh server, waiting for debug sessions") + + server := ssh.Server{ + PublicKeyHandler: s.authorizedKeys.authorize, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": SftpHandler, + }, + } + option := ssh.HostKeyPEM(s.hostPrivateKey) + option(&server) + + return &server +} + +func SftpHandler(sftpSession ssh.Session) { + sessionInfo := comms.NewSessionInfo( + sftpSession.LocalAddr().String(), + "sftp", + ) + session.Login(sessionInfo, NewSshAgentSession(sftpSession)) + defer session.LogOut(sessionInfo.ClientId) + + debugStream := io.Discard + serverOptions := []sftp.ServerOption{ + sftp.WithDebug(debugStream), + } + // activity for sftp means that the server is sending data to the client. + // In contrast to activity detection for ssh we use activity detection based on + // serveractivity. This approach ensures that long downloads are not interrupted + // by a timeout and are allowed to finish. + activityDetector := NewSftpActivityDetector(sftpSession) + server, err := sftp.NewServer( + activityDetector, + serverOptions..., + ) + if err != nil { + log.Printf("sftp tcpserver init error: %s\n", err) + return + } + if err := server.Serve(); err == io.EOF { + server.Close() + fmt.Println("sftp client exited session.") + } else if err != nil { + fmt.Println("sftp tcpserver completed with error:", err) + } +} diff --git a/pkg/agent/service/service.go b/pkg/agent/service/service.go new file mode 100644 index 0000000..6d43c33 --- /dev/null +++ b/pkg/agent/service/service.go @@ -0,0 +1 @@ +package service diff --git a/pkg/agent/session/session.go b/pkg/agent/session/session.go index 89b745f..e815899 100644 --- a/pkg/agent/session/session.go +++ b/pkg/agent/session/session.go @@ -5,13 +5,10 @@ import ( "fmt" "git.wamblee.org/converge/pkg/comms" "github.com/fsnotify/fsnotify" - "github.com/gliderlabs/ssh" - "io" "log" "os" "path/filepath" "runtime" - "strings" "text/template" "time" @@ -28,6 +25,11 @@ import ( // global configuration +type UserSession interface { + MessageUser(message string) + Type() string +} + type AgentState struct { commChannel comms.CommChannel startTime time.Time @@ -57,7 +59,7 @@ type AgentSession struct { startTime time.Time // For sending messages to the user - sshSession ssh.Session + session UserSession } var state AgentState @@ -119,9 +121,9 @@ func ConfigureAgent(commChannel comms.CommChannel, } -func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { +func Login(sessionInfo comms.SessionInfo, session UserSession) { events <- func() { - login(sessionInfo, sshSession) + login(sessionInfo, session) } } @@ -198,7 +200,7 @@ func holdFileMessage() string { return message } -func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { +func login(sessionInfo comms.SessionInfo, session UserSession) { log.Println("New login") hostname, _ := os.Hostname() @@ -217,8 +219,8 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { } agentSession := AgentSession{ - startTime: time.Now(), - sshSession: sshSession, + startTime: time.Now(), + session: session, } state.clients[sessionInfo.ClientId] = &agentSession state.lastUserActivityTime = time.Now() @@ -232,13 +234,13 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { logStatus() - printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) + printMessage(session, fmt.Sprintf("You are now on %s\n", hostname)) holdFileChange() - printHelpMessage(sshSession) + printHelpMessage(session) } -func printHelpMessage(sshSession ssh.Session) { - printMessage(sshSession, fmt.Sprintf(helpMessage, +func printHelpMessage(session UserSession) { + printMessage(session, fmt.Sprintf(helpMessage, state.agentExpiryDuration)) } @@ -265,11 +267,8 @@ func logOut(clientId string) { check() } -func printMessage(sshSession ssh.Session, message string) { - for _, line := range strings.Split(message, "\n") { - io.WriteString(sshSession.Stderr(), "### "+line+"\n\r") - } - io.WriteString(sshSession.Stderr(), "\n\r") +func printMessage(session UserSession, message string) { + session.MessageUser(message) } func logStatus() { @@ -277,10 +276,7 @@ func logStatus() { log.Println() log.Printf(fmt, "CLIENT", "START_TIME", "TYPE") for uid, session := range state.clients { - sessionType := session.sshSession.Subsystem() - if sessionType == "" { - sessionType = "ssh" - } + sessionType := session.session.Type() log.Printf(fmt, uid, session.startTime.Format(time.DateTime), sessionType) @@ -359,7 +355,7 @@ func check() { messageUsers( fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime))) //for _, session := range state.clients { - // printHelpMessage(session.sshSession) + // printHelpMessage(session.session) //} } @@ -374,6 +370,6 @@ func check() { func messageUsers(message string) { log.Printf("=== Notification to users: %s", message) for _, session := range state.clients { - printMessage(session.sshSession, message) + printMessage(session.session, message) } }