converge/pkg/agent/service/sshservice.go

129 lines
3.5 KiB
Go

package service
import (
"fmt"
"git.wamblee.org/converge/pkg/agent/session"
"git.wamblee.org/converge/pkg/agent/terminal"
"git.wamblee.org/converge/pkg/comms"
"git.wamblee.org/converge/pkg/support/ioutils"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
"io"
"log"
"net"
"os"
"strings"
)
// SshAgentService holds the configuration of the agent's SSH server: the host
// key it presents, the shell command spawned for interactive sessions, and the
// set of client public keys that are allowed to authenticate.
type SshAgentService struct {
// hostPrivateKey is the PEM-encoded host key, applied via ssh.HostKeyPEM.
hostPrivateKey []byte
// shellCommand is the command started on a pty for every interactive session.
shellCommand string
// authorizedKeys provides the public-key authentication callback (authorize).
authorizedKeys *AuthorizedPublicKeys
}
// NewSshAgentService returns an SshAgentService that presents hostPrivateKey
// (PEM-encoded) as its host key, spawns shellCommand for interactive sessions
// and authenticates clients against authorizedKeys.
func NewSshAgentService(hostPrivateKey []byte, shellCommand string,
	authorizedKeys *AuthorizedPublicKeys) SshAgentService {
	var svc SshAgentService
	svc.hostPrivateKey = hostPrivateKey
	svc.shellCommand = shellCommand
	svc.authorizedKeys = authorizedKeys
	return svc
}
// Run builds the SSH server and serves connections accepted from listener.
// It blocks while the server is serving. Run has no error result, so the
// error that ends serving is logged here instead of being silently dropped.
func (s SshAgentService) Run(listener net.Listener) {
	if err := s.sshServer().Serve(listener); err != nil {
		log.Printf("ssh server stopped: %v", err)
	}
}
// SshAgentSession adapts an ssh.Session for registration with the session
// package (see session.Login). It exposes MessageUser and Type for that purpose.
type SshAgentSession struct {
// session is the underlying SSH session (interactive shell or subsystem).
session ssh.Session
}
// NewSshAgentSession wraps session in an SshAgentSession.
func NewSshAgentSession(session ssh.Session) SshAgentSession {
	wrapped := SshAgentSession{}
	wrapped.session = session
	return wrapped
}
// MessageUser writes message to the client's stderr stream. Every line of the
// message is prefixed with "### " and terminated with "\n\r" (newline plus an
// explicit carriage return), and one extra "\n\r" is written afterwards as a
// separator. Write errors are ignored; this is best-effort user feedback.
func (s SshAgentSession) MessageUser(message string) {
	emit := func(text string) {
		io.WriteString(s.session.Stderr(), text)
	}
	lines := strings.Split(message, "\n")
	for i := range lines {
		emit("### " + lines[i] + "\n\r")
	}
	emit("\n\r")
}
// Type reports the kind of session: the name of the requested SSH subsystem
// (for example "sftp"), or "ssh" when no subsystem was requested.
func (s SshAgentSession) Type() string {
	if subsystem := s.session.Subsystem(); subsystem != "" {
		return subsystem
	}
	return "ssh"
}
// sshServer builds the agent's SSH server. Interactive sessions are handled by
// handleShellSession, the "sftp" subsystem by SftpHandler, and public-key
// authentication is delegated to s.authorizedKeys.authorize. The host key is
// taken from s.hostPrivateKey.
func (s *SshAgentService) sshServer() *ssh.Server {
	log.Println("starting ssh server, waiting for debug sessions")
	server := ssh.Server{
		// Set the handler on this server instead of calling ssh.Handle, which
		// would overwrite the package-global ssh.DefaultHandler.
		Handler:          s.handleShellSession,
		PublicKeyHandler: s.authorizedKeys.authorize,
		SubsystemHandlers: map[string]ssh.SubsystemHandler{
			"sftp": SftpHandler,
		},
	}
	// HostKeyPEM returns an option that can fail, e.g. on a malformed key.
	// This method has no error result, so the failure is at least logged.
	// NOTE(review): consider whether an unusable host key should be fatal.
	if err := ssh.HostKeyPEM(s.hostPrivateKey)(&server); err != nil {
		log.Printf("ssh server: cannot use host private key: %v", err)
	}
	return &server
}

// handleShellSession serves one interactive SSH session. It spawns
// s.shellCommand on a pty attached to the session, registers the session with
// the session package, relays data between shell and client, and then waits
// for the shell process to exit.
func (s *SshAgentService) handleShellSession(sshSession ssh.Session) {
	workingDirectory, err := os.Getwd()
	if err != nil {
		// Not fatal: the shell still starts, agentdir is just left empty.
		log.Printf("ssh session: cannot determine working directory: %v", err)
	}
	env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
	process, err := terminal.PtySpawner.Start(sshSession, env, s.shellCommand)
	if err != nil {
		// Fail only this session. An unrecovered panic here would terminate
		// the whole agent process, including every other open session.
		log.Printf("ssh session: cannot start shell %q: %v", s.shellCommand, err)
		fmt.Fprintf(sshSession.Stderr(), "cannot start shell: %v\n\r", err)
		// Best effort: the session is being torn down either way.
		_ = sshSession.Exit(1)
		return
	}
	sessionInfo := comms.NewSessionInfo(
		sshSession.LocalAddr().String(), "ssh",
	)
	session.Login(sessionInfo, NewSshAgentSession(sshSession))
	// For SSH, activity is detected from writes to the spawned shell, i.e.
	// user input.
	activityDetector := NewWriteDetector(process.Pipe())
	// Presumably blocks until one side of the stream closes; confirm in ioutils.
	ioutils.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
	session.LogOut(sessionInfo.ClientId)
	// Using Kill() here would create defunct processes. Under normal
	// circumstances Wait() is the best choice, because the process shuts down
	// on its own once it has lost its terminal.
	process.Wait()
}
// SftpHandler serves the "sftp" SSH subsystem for one session. The session is
// registered with the session package for the duration of the handler, and
// the sftp server runs on top of an activity detector wrapping the session.
// Errors are logged, not returned, since this is an ssh.SubsystemHandler.
func SftpHandler(sftpSession ssh.Session) {
	sessionInfo := comms.NewSessionInfo(
		sftpSession.LocalAddr().String(),
		"sftp",
	)
	session.Login(sessionInfo, NewSshAgentSession(sftpSession))
	defer session.LogOut(sessionInfo.ClientId)

	serverOptions := []sftp.ServerOption{
		// The sftp library's debug output is discarded.
		sftp.WithDebug(io.Discard),
	}
	// For sftp, activity means the server is sending data to the client.
	// Unlike the ssh shell (which detects user input), detection is based on
	// server activity, so long downloads are not interrupted by a timeout and
	// are allowed to finish.
	activityDetector := NewSftpActivityDetector(sftpSession)
	server, err := sftp.NewServer(
		activityDetector,
		serverOptions...,
	)
	if err != nil {
		log.Printf("sftp tcpserver init error: %s\n", err)
		return
	}
	// Close the server on every exit path, not only after a clean disconnect.
	defer server.Close()

	// Depending on the pkg/sftp version, a clean client disconnect is reported
	// either as io.EOF or as a nil error; both count as a normal exit.
	if err := server.Serve(); err != nil && err != io.EOF {
		log.Println("sftp tcpserver completed with error:", err)
		return
	}
	log.Println("sftp client exited session.")
}