simple session management solution with a .hold file and messages to the

user with better formatting.
This commit is contained in:
Erik Brakkee 2024-07-21 14:12:53 +02:00
parent 134c72d8d0
commit 22a3589d1d
4 changed files with 129 additions and 15 deletions

View File

@ -60,12 +60,14 @@ func setWinsize(f *os.File, w, h int) {
func sshServer(hostKeyFile string) *ssh.Server { func sshServer(hostKeyFile string) *ssh.Server {
ssh.Handle(func(s ssh.Session) { ssh.Handle(func(s ssh.Session) {
hostname, _ := os.Hostname() // TODO shell should be made configurable
io.WriteString(s, fmt.Sprintf("Your are now on %s\n\n", hostname))
cmd := exec.Command("bash") cmd := exec.Command("bash")
ptyReq, winCh, isPty := s.Pty() ptyReq, winCh, isPty := s.Pty()
if isPty { if isPty {
cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term)) workingDirectory, _ := os.Getwd()
cmd.Env = append(os.Environ(),
fmt.Sprintf("TERM=%s", ptyReq.Term),
fmt.Sprintf("agentdir=%s", workingDirectory))
f, err := pty.Start(cmd) f, err := pty.Start(cmd)
if err != nil { if err != nil {
panic(err) panic(err)
@ -83,7 +85,6 @@ func sshServer(hostKeyFile string) *ssh.Server {
}() }()
io.Copy(s, f) // stdout io.Copy(s, f) // stdout
cmd.Wait() cmd.Wait()
log.Println("User logged out")
agent.LogOut(uid) agent.LogOut(uid)
} else { } else {
io.WriteString(s, "No PTY requested.\n") io.WriteString(s, "No PTY requested.\n")
@ -154,9 +155,9 @@ func main() {
wsURL := os.Args[1] wsURL := os.Args[1]
advanceWarningTime := 1 * time.Minute advanceWarningTime := 1 * time.Minute
sessionExpiryTime := 5 * time.Minute agentExpriryTime := 2 * time.Minute
tickerInterval := 10 * time.Second tickerInterval := 10 * time.Second
agent.ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval) agent.ConfigureAgent(advanceWarningTime, agentExpriryTime, tickerInterval)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil { if err != nil {

13
pkg/agent/help.txt Normal file
View File

@ -0,0 +1,13 @@
Session is set to expire at %s
The session expires automatically after %d time.
If there are no more sessions after logging out, the agent
terminates.
You can extend this time using
touch $agentdir/.hold
To prevent the agent from exiting after the last sessioni is gone,
also use the above command in any shell.

View File

@ -1,22 +1,28 @@
package agent package agent
import ( import (
"fmt"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"io" "io"
"log" "log"
"os"
"strconv" "strconv"
"strings"
"time" "time"
_ "embed"
) )
// global configuration // global configuration
type AgentState struct { type AgentState struct {
startTime time.Time
// Advance warning time to notify the user of something important happening // Advance warning time to notify the user of something important happening
advanceWarningTime time.Duration advanceWarningTime time.Duration
// session expiry time // session expiry time
sessionExpiryTime time.Duration agentExpriryTime time.Duration
// ticker // ticker
tickerInterval time.Duration tickerInterval time.Duration
@ -35,10 +41,23 @@ type AgentSession struct {
var state AgentState var state AgentState
func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.Duration) { const holdFilename = ".hold"
//go:embed help.txt
var helpMessage string
func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
if fileExists(holdFilename) {
log.Printf("Removing hold file '%s'", holdFilename)
err := os.Remove(holdFilename)
if err != nil {
log.Printf("Could not remove hold file: '%s'", holdFilename)
}
}
state = AgentState{ state = AgentState{
startTime: time.Now(),
advanceWarningTime: advanceWarningTime, advanceWarningTime: advanceWarningTime,
sessionExpiryTime: sessionExpiryTime, agentExpriryTime: agentExpiryTime,
tickerInterval: tickerInterval, tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval), ticker: time.NewTicker(tickerInterval),
sessions: make(map[int]*AgentSession), sessions: make(map[int]*AgentSession),
@ -54,6 +73,25 @@ func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.D
func Login(sessionId int, sshSession ssh.Session) { func Login(sessionId int, sshSession ssh.Session) {
log.Println("New login") log.Println("New login")
hostname, _ := os.Hostname()
holdFileStats, ok := fileExistsWithStats(holdFilename)
if ok {
if holdFileStats.ModTime().After(time.Now()) {
// modification time in the future, leaving intact
log.Println("Hold file has modification time in the future, leaving intact")
} else {
log.Printf("Touching hold file '%s'", holdFilename)
err := os.Chtimes(holdFilename, time.Now(), time.Now())
if err != nil {
log.Printf("Could not touch hold file: '%s'", holdFilename)
}
}
}
PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
PrintHelpMessage(sshSession)
agentSession := AgentSession{ agentSession := AgentSession{
startTime: time.Now(), startTime: time.Now(),
sshSession: sshSession, sshSession: sshSession,
@ -62,6 +100,12 @@ func Login(sessionId int, sshSession ssh.Session) {
LogStatus() LogStatus()
} }
func PrintHelpMessage(sshSession ssh.Session) {
PrintMessage(sshSession, fmt.Sprintf(helpMessage,
state.expiryTime(holdFilename).Format(time.DateTime),
state.agentExpriryTime))
}
func LogOut(sessionId int) { func LogOut(sessionId int) {
log.Println("User logged out") log.Println("User logged out")
delete(state.sessions, sessionId) delete(state.sessions, sessionId)
@ -69,19 +113,75 @@ func LogOut(sessionId int) {
check() check()
} }
func PrintMessage(sshSession ssh.Session, message string) {
io.WriteString(sshSession.Stderr(), "\n\r###\n\r")
for _, line := range strings.Split(message, "\n") {
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
}
io.WriteString(sshSession.Stderr(), "\n\r")
}
func LogStatus() { func LogStatus() {
fmt := "%-20s %-20s" fmt := "%-20s %-20s"
log.Println() log.Println()
log.Printf(fmt, "UID", "START_TIME") log.Printf(fmt, "UID", "START_TIME")
for uid, session := range state.sessions { for uid, session := range state.sessions {
log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format("2006-01-02 15:04:05")) log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format(time.DateTime))
} }
log.Println() log.Println()
} }
func fileExistsWithStats(filename string) (os.FileInfo, bool) {
stats, err := os.Stat(filename)
return stats, !os.IsNotExist(err)
}
func fileExists(filename string) bool {
_, err := os.Stat(filename)
return !os.IsNotExist(err)
}
func (state *AgentState) expiryTime(filename string) time.Time {
stats, err := os.Stat(filename)
if err != nil {
return state.startTime
}
return stats.ModTime()
}
// Behavior to implement
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
// 2. The expiry time is relative to the modification time of the .hold file in the
// agent directory or, if that file does not exist, the start time of the agent.
// 3. if we are close to the expiry time then we message users with instruction on
// how to prevent the timeout
// 4. If the last user logs out, the aagent will exit immediately if no .hold file is
// present. Otherwise it will exit after the epxiry time. This allows users to
// reconnect later.
func check() { func check() {
log.Println("Timer is firing!") now := time.Now()
for _, session := range state.sessions {
io.WriteString(session.sshSession.Stderr(), "\n\nThe clock is ticking for you!\n\n") expiryTime := state.expiryTime(".hold").Add(state.agentExpriryTime)
if now.After(expiryTime) {
messageUsers("Expiry time was reached logging out")
time.Sleep(5 * time.Second)
log.Println("Agent exiting")
os.Exit(0)
}
if expiryTime.Sub(now) < state.advanceWarningTime {
messageUsers(
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
for _, session := range state.sessions {
PrintHelpMessage(session.sshSession)
}
}
}
func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message)
for _, session := range state.sessions {
PrintMessage(session.sshSession, message)
} }
} }

View File

@ -65,14 +65,14 @@ func (admin *Admin) logStatus() {
for _, agent := range admin.agents { for _, agent := range admin.agents {
agent.clientSession.RemoteAddr() agent.clientSession.RemoteAddr()
log.Printf("%-20s %-20s %-20s\n", agent.publicId, log.Printf("%-20s %-20s %-20s\n", agent.publicId,
agent.startTime.Format("2006-01-02 15:04:05"), agent.startTime.Format(time.DateTime),
agent.clientSession.RemoteAddr().String()) agent.clientSession.RemoteAddr().String())
} }
log.Println("") log.Println("")
log.Printf("%-20s %-20s %-20s\n", "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS") log.Printf("%-20s %-20s %-20s\n", "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS")
for _, client := range admin.clients { for _, client := range admin.clients {
log.Printf("%-20s %-20s %-20s", client.publicId, log.Printf("%-20s %-20s %-20s", client.publicId,
client.startTime.Format("2006-01-02 15:04:05"), client.startTime.Format(time.DateTime),
client.client.RemoteAddr()) client.client.RemoteAddr())
} }
log.Printf("\n") log.Printf("\n")