agent/session.go no longer depends on ssh.Session and uses an

internal interface.
This commit is contained in:
Erik Brakkee 2024-08-26 09:54:05 +02:00
parent 4f97b29776
commit 781c14dcf4
4 changed files with 145 additions and 102 deletions

View File

@ -5,13 +5,11 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"git.wamblee.org/converge/pkg/agent/session" "git.wamblee.org/converge/pkg/agent/session"
"git.wamblee.org/converge/pkg/agent/terminal"
"git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/support/iowrappers" "git.wamblee.org/converge/pkg/support/iowrappers"
"git.wamblee.org/converge/pkg/support/websocketutil" "git.wamblee.org/converge/pkg/support/websocketutil"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/sftp"
"io" "io"
"log" "log"
"net" "net"
@ -29,79 +27,6 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
) )
func SftpHandler(sftpSession ssh.Session) {
sessionInfo := comms.NewSessionInfo(
sftpSession.LocalAddr().String(),
"sftp",
)
session.Login(sessionInfo, sftpSession)
defer session.LogOut(sessionInfo.ClientId)
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
sftp.WithDebug(debugStream),
}
// activity for sftp means that the server is sending data to the client.
// In contrast to activity detection for ssh we use activity detection based on
// serveractivity. This approach ensures that long downloads are not interrupted
// by a timeout and are allowed to finish.
activityDetector := NewSftpActivityDetector(sftpSession)
server, err := sftp.NewServer(
activityDetector,
serverOptions...,
)
if err != nil {
log.Printf("sftp tcpserver init error: %s\n", err)
return
}
if err := server.Serve(); err == io.EOF {
server.Close()
fmt.Println("sftp client exited session.")
} else if err != nil {
fmt.Println("sftp tcpserver completed with error:", err)
}
}
func sshServer(hostPrivateKey []byte, shellCommand string,
authorizedPublicKeys *AuthorizedPublicKeys) *ssh.Server {
ssh.Handle(func(sshSession ssh.Session) {
workingDirectory, _ := os.Getwd()
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand)
if err != nil {
panic(err)
}
sessionInfo := comms.NewSessionInfo(
sshSession.LocalAddr().String(), "ssh",
)
session.Login(sessionInfo, sshSession)
// For SSH we detect activity when there are writes to the shell that was spanwedn.
// This means user input.
activityDetector := NewWriteDetector(process.Pipe())
iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
session.LogOut(sessionInfo.ClientId)
// will cause addition goroutines to remmain alive when the SSH
// session is killed. For now acceptable since the agent is a short-lived
// process. Using Kill() here will create defunct processes and in normal
// circummstances Wait() will be the best because the process will be shutting
// down automatically becuase it has lost its terminal.
process.Wait()
})
log.Println("starting ssh server, waiting for debug sessions")
server := ssh.Server{
PublicKeyHandler: authorizedPublicKeys.authorize,
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": SftpHandler,
},
}
option := ssh.HostKeyPEM(hostPrivateKey)
option(&server)
return &server
}
func echoServer(conn io.ReadWriter) { func echoServer(conn io.ReadWriter) {
log.Println("Echo service started") log.Println("Echo service started")
io.Copy(conn, conn) io.Copy(conn, conn)
@ -330,9 +255,11 @@ func main() {
var service AgentService var service AgentService
service = ListenerServer(func() *ssh.Server { service = SshAgentService{
return sshServer(registration.HostPrivateKey, shell, authorizedKeys) hostPrivateKey: registration.HostPrivateKey,
}) shellCommand: shell,
authorizedKeys: authorizedKeys,
}
//service = ConnectionServer(netCatServer) //service = ConnectionServer(netCatServer)
//service = ConnectionServer(echoServer) //service = ConnectionServer(echoServer)
log.Println() log.Println()

119
cmd/agent/sshservice.go Normal file
View File

@ -0,0 +1,119 @@
package main
import (
"fmt"
"git.wamblee.org/converge/pkg/agent/session"
"git.wamblee.org/converge/pkg/agent/terminal"
"git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/support/iowrappers"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
"io"
"log"
"net"
"os"
"strings"
)
type SshAgentService struct {
hostPrivateKey []byte
shellCommand string
authorizedKeys *AuthorizedPublicKeys
}
func (s SshAgentService) Run(listener net.Listener) {
s.sshServer().Serve(listener)
}
type SshAgentSession struct {
session ssh.Session
}
func NewSshAgentSession(session ssh.Session) SshAgentSession {
return SshAgentSession{session: session}
}
func (s SshAgentSession) MessageUser(message string) {
for _, line := range strings.Split(message, "\n") {
io.WriteString(s.session.Stderr(), "### "+line+"\n\r")
}
io.WriteString(s.session.Stderr(), "\n\r")
}
func (s SshAgentSession) Type() string {
sessionType := s.session.Subsystem()
if sessionType == "" {
sessionType = "ssh"
}
return sessionType
}
func (s *SshAgentService) sshServer() *ssh.Server {
ssh.Handle(func(sshSession ssh.Session) {
workingDirectory, _ := os.Getwd()
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
process, err := terminal.PtySpawner.Start(sshSession, env, s.shellCommand)
if err != nil {
panic(err)
}
sessionInfo := comms.NewSessionInfo(
sshSession.LocalAddr().String(), "ssh",
)
session.Login(sessionInfo, NewSshAgentSession(sshSession))
// For SSH we detect activity when there are writes to the shell that was spawned.
// This means user input.
activityDetector := NewWriteDetector(process.Pipe())
iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
session.LogOut(sessionInfo.ClientId)
// Using Kill() here will create defunct processes and in normal
// circumstances Wait() will be the best because the process will be shutting
// down automatically becuase it has lost its terminal.
process.Wait()
})
log.Println("starting ssh server, waiting for debug sessions")
server := ssh.Server{
PublicKeyHandler: s.authorizedKeys.authorize,
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": SftpHandler,
},
}
option := ssh.HostKeyPEM(s.hostPrivateKey)
option(&server)
return &server
}
func SftpHandler(sftpSession ssh.Session) {
sessionInfo := comms.NewSessionInfo(
sftpSession.LocalAddr().String(),
"sftp",
)
session.Login(sessionInfo, NewSshAgentSession(sftpSession))
defer session.LogOut(sessionInfo.ClientId)
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
sftp.WithDebug(debugStream),
}
// activity for sftp means that the server is sending data to the client.
// In contrast to activity detection for ssh we use activity detection based on
// serveractivity. This approach ensures that long downloads are not interrupted
// by a timeout and are allowed to finish.
activityDetector := NewSftpActivityDetector(sftpSession)
server, err := sftp.NewServer(
activityDetector,
serverOptions...,
)
if err != nil {
log.Printf("sftp tcpserver init error: %s\n", err)
return
}
if err := server.Serve(); err == io.EOF {
server.Close()
fmt.Println("sftp client exited session.")
} else if err != nil {
fmt.Println("sftp tcpserver completed with error:", err)
}
}

View File

@ -0,0 +1 @@
package service

View File

@ -5,13 +5,10 @@ import (
"fmt" "fmt"
"git.wamblee.org/converge/pkg/comms" "git.wamblee.org/converge/pkg/comms"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/gliderlabs/ssh"
"io"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"text/template" "text/template"
"time" "time"
@ -28,6 +25,11 @@ import (
// global configuration // global configuration
type UserSession interface {
MessageUser(message string)
Type() string
}
type AgentState struct { type AgentState struct {
commChannel comms.CommChannel commChannel comms.CommChannel
startTime time.Time startTime time.Time
@ -57,7 +59,7 @@ type AgentSession struct {
startTime time.Time startTime time.Time
// For sending messages to the user // For sending messages to the user
sshSession ssh.Session session UserSession
} }
var state AgentState var state AgentState
@ -119,9 +121,9 @@ func ConfigureAgent(commChannel comms.CommChannel,
} }
func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { func Login(sessionInfo comms.SessionInfo, session UserSession) {
events <- func() { events <- func() {
login(sessionInfo, sshSession) login(sessionInfo, session)
} }
} }
@ -198,7 +200,7 @@ func holdFileMessage() string {
return message return message
} }
func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { func login(sessionInfo comms.SessionInfo, session UserSession) {
log.Println("New login") log.Println("New login")
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
@ -218,7 +220,7 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
agentSession := AgentSession{ agentSession := AgentSession{
startTime: time.Now(), startTime: time.Now(),
sshSession: sshSession, session: session,
} }
state.clients[sessionInfo.ClientId] = &agentSession state.clients[sessionInfo.ClientId] = &agentSession
state.lastUserActivityTime = time.Now() state.lastUserActivityTime = time.Now()
@ -232,13 +234,13 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
logStatus() logStatus()
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) printMessage(session, fmt.Sprintf("You are now on %s\n", hostname))
holdFileChange() holdFileChange()
printHelpMessage(sshSession) printHelpMessage(session)
} }
func printHelpMessage(sshSession ssh.Session) { func printHelpMessage(session UserSession) {
printMessage(sshSession, fmt.Sprintf(helpMessage, printMessage(session, fmt.Sprintf(helpMessage,
state.agentExpiryDuration)) state.agentExpiryDuration))
} }
@ -265,11 +267,8 @@ func logOut(clientId string) {
check() check()
} }
func printMessage(sshSession ssh.Session, message string) { func printMessage(session UserSession, message string) {
for _, line := range strings.Split(message, "\n") { session.MessageUser(message)
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
}
io.WriteString(sshSession.Stderr(), "\n\r")
} }
func logStatus() { func logStatus() {
@ -277,10 +276,7 @@ func logStatus() {
log.Println() log.Println()
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE") log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
for uid, session := range state.clients { for uid, session := range state.clients {
sessionType := session.sshSession.Subsystem() sessionType := session.session.Type()
if sessionType == "" {
sessionType = "ssh"
}
log.Printf(fmt, uid, log.Printf(fmt, uid,
session.startTime.Format(time.DateTime), session.startTime.Format(time.DateTime),
sessionType) sessionType)
@ -359,7 +355,7 @@ func check() {
messageUsers( messageUsers(
fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime))) fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime)))
//for _, session := range state.clients { //for _, session := range state.clients {
// printHelpMessage(session.sshSession) // printHelpMessage(session.session)
//} //}
} }
@ -374,6 +370,6 @@ func check() {
func messageUsers(message string) { func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message) log.Printf("=== Notification to users: %s", message)
for _, session := range state.clients { for _, session := range state.clients {
printMessage(session.sshSession, message) printMessage(session.session, message)
} }
} }