From 85caa6cb5a0e85514c6e6bc4c15553719b13cdd5 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Sun, 21 Jul 2024 14:12:53 +0200 Subject: [PATCH] simple session management solution with a .hold file and messages to the user with better formatting. --- cmd/agent/agent.go | 13 ++--- pkg/agent/help.txt | 13 +++++ pkg/agent/session.go | 114 +++++++++++++++++++++++++++++++++++++++--- pkg/converge/admin.go | 4 +- 4 files changed, 129 insertions(+), 15 deletions(-) create mode 100644 pkg/agent/help.txt diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 6d663cc..38d2c10 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -60,12 +60,14 @@ func setWinsize(f *os.File, w, h int) { func sshServer(hostKeyFile string) *ssh.Server { ssh.Handle(func(s ssh.Session) { - hostname, _ := os.Hostname() - io.WriteString(s, fmt.Sprintf("Your are now on %s\n\n", hostname)) + // TODO shell should be made configurable cmd := exec.Command("bash") ptyReq, winCh, isPty := s.Pty() if isPty { - cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term)) + workingDirectory, _ := os.Getwd() + cmd.Env = append(os.Environ(), + fmt.Sprintf("TERM=%s", ptyReq.Term), + fmt.Sprintf("agentdir=%s", workingDirectory)) f, err := pty.Start(cmd) if err != nil { panic(err) @@ -83,7 +85,6 @@ func sshServer(hostKeyFile string) *ssh.Server { }() io.Copy(s, f) // stdout cmd.Wait() - log.Println("User logged out") agent.LogOut(uid) } else { io.WriteString(s, "No PTY requested.\n") @@ -154,9 +155,9 @@ func main() { wsURL := os.Args[1] advanceWarningTime := 1 * time.Minute - sessionExpiryTime := 5 * time.Minute + agentExpriryTime := 2 * time.Minute tickerInterval := 10 * time.Second - agent.ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval) + agent.ConfigureAgent(advanceWarningTime, agentExpriryTime, tickerInterval) conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { diff --git a/pkg/agent/help.txt b/pkg/agent/help.txt new file mode 100644 index 0000000..44435ef --- /dev/null +++ b/pkg/agent/help.txt @@ -0,0 +1,13 @@ +Session is set to expire at %s + +The session expires automatically after %d time. +If there are no more sessions after logging out, the agent +terminates. + +You can extend this time using + + touch $agentdir/.hold + +To prevent the agent from exiting after the last sessioni is gone, +also use the above command in any shell. + diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 4511269..3c705c6 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -1,22 +1,28 @@ package agent import ( + "fmt" "github.com/gliderlabs/ssh" "io" "log" + "os" "strconv" + "strings" "time" + + _ "embed" ) // global configuration type AgentState struct { + startTime time.Time // Advance warning time to notify the user of something important happening advanceWarningTime time.Duration // session expiry time - sessionExpiryTime time.Duration + agentExpriryTime time.Duration // ticker tickerInterval time.Duration @@ -35,10 +41,23 @@ type AgentSession struct { var state AgentState -func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.Duration) { +const holdFilename = ".hold" + +//go:embed help.txt +var helpMessage string + +func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { + if fileExists(holdFilename) { + log.Printf("Removing hold file '%s'", holdFilename) + err := os.Remove(holdFilename) + if err != nil { + log.Printf("Could not remove hold file: '%s'", holdFilename) + } + } state = AgentState{ + startTime: time.Now(), advanceWarningTime: advanceWarningTime, - sessionExpiryTime: sessionExpiryTime, + agentExpriryTime: agentExpiryTime, tickerInterval: tickerInterval, ticker: time.NewTicker(tickerInterval), sessions: make(map[int]*AgentSession), @@ -54,6 +73,25 @@ func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.D func Login(sessionId int, sshSession ssh.Session) { log.Println("New login") + hostname, _ := os.Hostname() + + holdFileStats, ok := fileExistsWithStats(holdFilename) + if ok { + if holdFileStats.ModTime().After(time.Now()) { + // modification time in the future, leaving intact + log.Println("Hold file has modification time in the future, leaving intact") + } else { + log.Printf("Touching hold file '%s'", holdFilename) + err := os.Chtimes(holdFilename, time.Now(), time.Now()) + if err != nil { + log.Printf("Could not touch hold file: '%s'", holdFilename) + } + } + } + + PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) + PrintHelpMessage(sshSession) + agentSession := AgentSession{ startTime: time.Now(), sshSession: sshSession, @@ -62,6 +100,12 @@ func Login(sessionId int, sshSession ssh.Session) { LogStatus() } +func PrintHelpMessage(sshSession ssh.Session) { + PrintMessage(sshSession, fmt.Sprintf(helpMessage, + state.expiryTime(holdFilename).Format(time.DateTime), + state.agentExpriryTime)) +} + func LogOut(sessionId int) { log.Println("User logged out") delete(state.sessions, sessionId) @@ -69,19 +113,75 @@ func LogOut(sessionId int) { check() } +func PrintMessage(sshSession ssh.Session, message string) { + io.WriteString(sshSession.Stderr(), "\n\r###\n\r") + for _, line := range strings.Split(message, "\n") { + io.WriteString(sshSession.Stderr(), "### "+line+"\n\r") + } + io.WriteString(sshSession.Stderr(), "\n\r") +} + func LogStatus() { fmt := "%-20s %-20s" log.Println() log.Printf(fmt, "UID", "START_TIME") for uid, session := range state.sessions { - log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format("2006-01-02 15:04:05")) + log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format(time.DateTime)) } log.Println() } +func fileExistsWithStats(filename string) (os.FileInfo, bool) { + stats, err := os.Stat(filename) + return stats, !os.IsNotExist(err) +} + +func fileExists(filename string) bool { + _, err := os.Stat(filename) + return !os.IsNotExist(err) +} + +func (state *AgentState) expiryTime(filename string) time.Time { + stats, err := os.Stat(filename) + if err != nil { + return state.startTime + } + return stats.ModTime() +} + +// Behavior to implement +// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime +// 2. The expiry time is relative to the modification time of the .hold file in the +// agent directory or, if that file does not exist, the start time of the agent. +// 3. if we are close to the expiry time then we message users with instruction on +// how to prevent the timeout +// 4. If the last user logs out, the aagent will exit immediately if no .hold file is +// present. Otherwise it will exit after the epxiry time. This allows users to +// reconnect later. func check() { - log.Println("Timer is firing!") - for _, session := range state.sessions { - io.WriteString(session.sshSession.Stderr(), "\n\nThe clock is ticking for you!\n\n") + now := time.Now() + + expiryTime := state.expiryTime(".hold").Add(state.agentExpriryTime) + + if now.After(expiryTime) { + messageUsers("Expiry time was reached logging out") + time.Sleep(5 * time.Second) + log.Println("Agent exiting") + os.Exit(0) + } + + if expiryTime.Sub(now) < state.advanceWarningTime { + messageUsers( + fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) + for _, session := range state.sessions { + PrintHelpMessage(session.sshSession) + } + } +} + +func messageUsers(message string) { + log.Printf("=== Notification to users: %s", message) + for _, session := range state.sessions { + PrintMessage(session.sshSession, message) } } diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index 5372d8d..cd80685 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -65,14 +65,14 @@ func (admin *Admin) logStatus() { for _, agent := range admin.agents { agent.clientSession.RemoteAddr() log.Printf("%-20s %-20s %-20s\n", agent.publicId, - agent.startTime.Format("2006-01-02 15:04:05"), + agent.startTime.Format(time.DateTime), agent.clientSession.RemoteAddr().String()) } log.Println("") log.Printf("%-20s %-20s %-20s\n", "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS") for _, client := range admin.clients { log.Printf("%-20s %-20s %-20s", client.publicId, - client.startTime.Format("2006-01-02 15:04:05"), + client.startTime.Format(time.DateTime), client.client.RemoteAddr()) } log.Printf("\n")