agent/session.go no longer depends on ssh.Session and uses an
internal interface.
This commit is contained in:
parent
4f97b29776
commit
781c14dcf4
@ -5,13 +5,11 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"git.wamblee.org/converge/pkg/agent/session"
|
||||
"git.wamblee.org/converge/pkg/agent/terminal"
|
||||
"git.wamblee.org/converge/pkg/comms"
|
||||
"git.wamblee.org/converge/pkg/support/iowrappers"
|
||||
"git.wamblee.org/converge/pkg/support/websocketutil"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/sftp"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@ -29,79 +27,6 @@ import (
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
func SftpHandler(sftpSession ssh.Session) {
|
||||
sessionInfo := comms.NewSessionInfo(
|
||||
sftpSession.LocalAddr().String(),
|
||||
"sftp",
|
||||
)
|
||||
session.Login(sessionInfo, sftpSession)
|
||||
defer session.LogOut(sessionInfo.ClientId)
|
||||
|
||||
debugStream := io.Discard
|
||||
serverOptions := []sftp.ServerOption{
|
||||
sftp.WithDebug(debugStream),
|
||||
}
|
||||
// activity for sftp means that the server is sending data to the client.
|
||||
// In contrast to activity detection for ssh we use activity detection based on
|
||||
// serveractivity. This approach ensures that long downloads are not interrupted
|
||||
// by a timeout and are allowed to finish.
|
||||
activityDetector := NewSftpActivityDetector(sftpSession)
|
||||
server, err := sftp.NewServer(
|
||||
activityDetector,
|
||||
serverOptions...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("sftp tcpserver init error: %s\n", err)
|
||||
return
|
||||
}
|
||||
if err := server.Serve(); err == io.EOF {
|
||||
server.Close()
|
||||
fmt.Println("sftp client exited session.")
|
||||
} else if err != nil {
|
||||
fmt.Println("sftp tcpserver completed with error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sshServer(hostPrivateKey []byte, shellCommand string,
|
||||
authorizedPublicKeys *AuthorizedPublicKeys) *ssh.Server {
|
||||
ssh.Handle(func(sshSession ssh.Session) {
|
||||
workingDirectory, _ := os.Getwd()
|
||||
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
|
||||
process, err := terminal.PtySpawner.Start(sshSession, env, shellCommand)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sessionInfo := comms.NewSessionInfo(
|
||||
sshSession.LocalAddr().String(), "ssh",
|
||||
)
|
||||
session.Login(sessionInfo, sshSession)
|
||||
// For SSH we detect activity when there are writes to the shell that was spanwedn.
|
||||
// This means user input.
|
||||
activityDetector := NewWriteDetector(process.Pipe())
|
||||
iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
|
||||
session.LogOut(sessionInfo.ClientId)
|
||||
// will cause addition goroutines to remmain alive when the SSH
|
||||
// session is killed. For now acceptable since the agent is a short-lived
|
||||
// process. Using Kill() here will create defunct processes and in normal
|
||||
// circummstances Wait() will be the best because the process will be shutting
|
||||
// down automatically becuase it has lost its terminal.
|
||||
process.Wait()
|
||||
})
|
||||
|
||||
log.Println("starting ssh server, waiting for debug sessions")
|
||||
|
||||
server := ssh.Server{
|
||||
PublicKeyHandler: authorizedPublicKeys.authorize,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": SftpHandler,
|
||||
},
|
||||
}
|
||||
option := ssh.HostKeyPEM(hostPrivateKey)
|
||||
option(&server)
|
||||
|
||||
return &server
|
||||
}
|
||||
|
||||
func echoServer(conn io.ReadWriter) {
|
||||
log.Println("Echo service started")
|
||||
io.Copy(conn, conn)
|
||||
@ -330,9 +255,11 @@ func main() {
|
||||
|
||||
var service AgentService
|
||||
|
||||
service = ListenerServer(func() *ssh.Server {
|
||||
return sshServer(registration.HostPrivateKey, shell, authorizedKeys)
|
||||
})
|
||||
service = SshAgentService{
|
||||
hostPrivateKey: registration.HostPrivateKey,
|
||||
shellCommand: shell,
|
||||
authorizedKeys: authorizedKeys,
|
||||
}
|
||||
//service = ConnectionServer(netCatServer)
|
||||
//service = ConnectionServer(echoServer)
|
||||
log.Println()
|
||||
|
119
cmd/agent/sshservice.go
Normal file
119
cmd/agent/sshservice.go
Normal file
@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.wamblee.org/converge/pkg/agent/session"
|
||||
"git.wamblee.org/converge/pkg/agent/terminal"
|
||||
"git.wamblee.org/converge/pkg/comms"
|
||||
"git.wamblee.org/converge/pkg/support/iowrappers"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SshAgentService struct {
|
||||
hostPrivateKey []byte
|
||||
shellCommand string
|
||||
authorizedKeys *AuthorizedPublicKeys
|
||||
}
|
||||
|
||||
func (s SshAgentService) Run(listener net.Listener) {
|
||||
s.sshServer().Serve(listener)
|
||||
}
|
||||
|
||||
type SshAgentSession struct {
|
||||
session ssh.Session
|
||||
}
|
||||
|
||||
func NewSshAgentSession(session ssh.Session) SshAgentSession {
|
||||
return SshAgentSession{session: session}
|
||||
}
|
||||
|
||||
func (s SshAgentSession) MessageUser(message string) {
|
||||
for _, line := range strings.Split(message, "\n") {
|
||||
io.WriteString(s.session.Stderr(), "### "+line+"\n\r")
|
||||
}
|
||||
io.WriteString(s.session.Stderr(), "\n\r")
|
||||
}
|
||||
|
||||
func (s SshAgentSession) Type() string {
|
||||
sessionType := s.session.Subsystem()
|
||||
if sessionType == "" {
|
||||
sessionType = "ssh"
|
||||
}
|
||||
return sessionType
|
||||
}
|
||||
|
||||
func (s *SshAgentService) sshServer() *ssh.Server {
|
||||
ssh.Handle(func(sshSession ssh.Session) {
|
||||
workingDirectory, _ := os.Getwd()
|
||||
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
|
||||
process, err := terminal.PtySpawner.Start(sshSession, env, s.shellCommand)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sessionInfo := comms.NewSessionInfo(
|
||||
sshSession.LocalAddr().String(), "ssh",
|
||||
)
|
||||
session.Login(sessionInfo, NewSshAgentSession(sshSession))
|
||||
// For SSH we detect activity when there are writes to the shell that was spawned.
|
||||
// This means user input.
|
||||
activityDetector := NewWriteDetector(process.Pipe())
|
||||
iowrappers.SynchronizeStreams("shell -- ssh", activityDetector, sshSession)
|
||||
session.LogOut(sessionInfo.ClientId)
|
||||
// Using Kill() here will create defunct processes and in normal
|
||||
// circumstances Wait() will be the best because the process will be shutting
|
||||
// down automatically becuase it has lost its terminal.
|
||||
process.Wait()
|
||||
})
|
||||
|
||||
log.Println("starting ssh server, waiting for debug sessions")
|
||||
|
||||
server := ssh.Server{
|
||||
PublicKeyHandler: s.authorizedKeys.authorize,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": SftpHandler,
|
||||
},
|
||||
}
|
||||
option := ssh.HostKeyPEM(s.hostPrivateKey)
|
||||
option(&server)
|
||||
|
||||
return &server
|
||||
}
|
||||
|
||||
func SftpHandler(sftpSession ssh.Session) {
|
||||
sessionInfo := comms.NewSessionInfo(
|
||||
sftpSession.LocalAddr().String(),
|
||||
"sftp",
|
||||
)
|
||||
session.Login(sessionInfo, NewSshAgentSession(sftpSession))
|
||||
defer session.LogOut(sessionInfo.ClientId)
|
||||
|
||||
debugStream := io.Discard
|
||||
serverOptions := []sftp.ServerOption{
|
||||
sftp.WithDebug(debugStream),
|
||||
}
|
||||
// activity for sftp means that the server is sending data to the client.
|
||||
// In contrast to activity detection for ssh we use activity detection based on
|
||||
// serveractivity. This approach ensures that long downloads are not interrupted
|
||||
// by a timeout and are allowed to finish.
|
||||
activityDetector := NewSftpActivityDetector(sftpSession)
|
||||
server, err := sftp.NewServer(
|
||||
activityDetector,
|
||||
serverOptions...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("sftp tcpserver init error: %s\n", err)
|
||||
return
|
||||
}
|
||||
if err := server.Serve(); err == io.EOF {
|
||||
server.Close()
|
||||
fmt.Println("sftp client exited session.")
|
||||
} else if err != nil {
|
||||
fmt.Println("sftp tcpserver completed with error:", err)
|
||||
}
|
||||
}
|
1
pkg/agent/service/service.go
Normal file
1
pkg/agent/service/service.go
Normal file
@ -0,0 +1 @@
|
||||
package service
|
@ -5,13 +5,10 @@ import (
|
||||
"fmt"
|
||||
"git.wamblee.org/converge/pkg/comms"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
@ -28,6 +25,11 @@ import (
|
||||
|
||||
// global configuration
|
||||
|
||||
type UserSession interface {
|
||||
MessageUser(message string)
|
||||
Type() string
|
||||
}
|
||||
|
||||
type AgentState struct {
|
||||
commChannel comms.CommChannel
|
||||
startTime time.Time
|
||||
@ -57,7 +59,7 @@ type AgentSession struct {
|
||||
startTime time.Time
|
||||
|
||||
// For sending messages to the user
|
||||
sshSession ssh.Session
|
||||
session UserSession
|
||||
}
|
||||
|
||||
var state AgentState
|
||||
@ -119,9 +121,9 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
||||
|
||||
}
|
||||
|
||||
func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
|
||||
func Login(sessionInfo comms.SessionInfo, session UserSession) {
|
||||
events <- func() {
|
||||
login(sessionInfo, sshSession)
|
||||
login(sessionInfo, session)
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,7 +200,7 @@ func holdFileMessage() string {
|
||||
return message
|
||||
}
|
||||
|
||||
func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
|
||||
func login(sessionInfo comms.SessionInfo, session UserSession) {
|
||||
log.Println("New login")
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
@ -217,8 +219,8 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
|
||||
}
|
||||
|
||||
agentSession := AgentSession{
|
||||
startTime: time.Now(),
|
||||
sshSession: sshSession,
|
||||
startTime: time.Now(),
|
||||
session: session,
|
||||
}
|
||||
state.clients[sessionInfo.ClientId] = &agentSession
|
||||
state.lastUserActivityTime = time.Now()
|
||||
@ -232,13 +234,13 @@ func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
|
||||
|
||||
logStatus()
|
||||
|
||||
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
|
||||
printMessage(session, fmt.Sprintf("You are now on %s\n", hostname))
|
||||
holdFileChange()
|
||||
printHelpMessage(sshSession)
|
||||
printHelpMessage(session)
|
||||
}
|
||||
|
||||
func printHelpMessage(sshSession ssh.Session) {
|
||||
printMessage(sshSession, fmt.Sprintf(helpMessage,
|
||||
func printHelpMessage(session UserSession) {
|
||||
printMessage(session, fmt.Sprintf(helpMessage,
|
||||
state.agentExpiryDuration))
|
||||
}
|
||||
|
||||
@ -265,11 +267,8 @@ func logOut(clientId string) {
|
||||
check()
|
||||
}
|
||||
|
||||
func printMessage(sshSession ssh.Session, message string) {
|
||||
for _, line := range strings.Split(message, "\n") {
|
||||
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
|
||||
}
|
||||
io.WriteString(sshSession.Stderr(), "\n\r")
|
||||
func printMessage(session UserSession, message string) {
|
||||
session.MessageUser(message)
|
||||
}
|
||||
|
||||
func logStatus() {
|
||||
@ -277,10 +276,7 @@ func logStatus() {
|
||||
log.Println()
|
||||
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
|
||||
for uid, session := range state.clients {
|
||||
sessionType := session.sshSession.Subsystem()
|
||||
if sessionType == "" {
|
||||
sessionType = "ssh"
|
||||
}
|
||||
sessionType := session.session.Type()
|
||||
log.Printf(fmt, uid,
|
||||
session.startTime.Format(time.DateTime),
|
||||
sessionType)
|
||||
@ -359,7 +355,7 @@ func check() {
|
||||
messageUsers(
|
||||
fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime)))
|
||||
//for _, session := range state.clients {
|
||||
// printHelpMessage(session.sshSession)
|
||||
// printHelpMessage(session.session)
|
||||
//}
|
||||
}
|
||||
|
||||
@ -374,6 +370,6 @@ func check() {
|
||||
func messageUsers(message string) {
|
||||
log.Printf("=== Notification to users: %s", message)
|
||||
for _, session := range state.clients {
|
||||
printMessage(session.sshSession, message)
|
||||
printMessage(session.session, message)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user