From ddc3b24ebf293f28d15d72167f858ad109efb013 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Wed, 24 Jul 2024 19:23:51 +0200 Subject: [PATCH] clean solution for concurrence in session.go by serializing all external calls (apart from initialization) through a channel. --- pkg/agent/help.txt | 9 +---- pkg/agent/session.go | 93 +++++++++++++++++++++++++++++--------------- pkg/async/async.go | 18 +++++++++ 3 files changed, 81 insertions(+), 39 deletions(-) create mode 100644 pkg/async/async.go diff --git a/pkg/agent/help.txt b/pkg/agent/help.txt index bf90031..6b3b99c 100644 --- a/pkg/agent/help.txt +++ b/pkg/agent/help.txt @@ -1,10 +1,5 @@ -Session is set to expire at %v -The session expires automatically after %v. -If there are no more sessions after logging out, the agent -terminates. - -You can extend this time using +You can extend expiry of the session using {{ if eq .os "windows" -}} echo > %agentdir%\.hold @@ -13,7 +8,7 @@ You can extend this time using {{- end }} The expiry time is equal to the modification time of the .hold -file with the expiry duration added. +file with the expiry duration (%v) added. To prevent the agent from exiting after the last session is gone, also use the above command in any shell. diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 5339e51..ef2fd5c 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "converge/pkg/async" "fmt" "github.com/fsnotify/fsnotify" "github.com/gliderlabs/ssh" @@ -38,6 +39,9 @@ type AgentState struct { // session expiry time agentExpriryTime time.Duration + // Last expiry time reported to the user. + lastExpiryTimmeReported time.Time + // ticker tickerInterval time.Duration ticker *time.Ticker @@ -63,6 +67,11 @@ const holdFilename = ".hold" var helpMessageTemplate string var helpMessage = formatHelpMessage() +// Events channel for asynchronous events. +var events = make(chan func(), 10) + +// External interface, asynchronous, apart from the initialization. + func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { if fileExists(holdFilename) { log.Printf("Removing hold file '%s'", holdFilename) @@ -72,24 +81,42 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur } } state = AgentState{ - startTime: time.Now(), - advanceWarningTime: advanceWarningTime, - agentExpriryTime: agentExpiryTime, - tickerInterval: tickerInterval, - ticker: time.NewTicker(tickerInterval), - sessions: make(map[int]*AgentSession), - agentUsed: false, + startTime: time.Now(), + advanceWarningTime: advanceWarningTime, + agentExpriryTime: agentExpiryTime, + lastExpiryTimmeReported: time.Time{}, + tickerInterval: tickerInterval, + ticker: time.NewTicker(tickerInterval), + sessions: make(map[int]*AgentSession), + agentUsed: false, } go func() { for { <-state.ticker.C - check() + events <- async.Async(check) } }() go monitorHoldFile() + go func() { + for { + event := <-events + event() + } + }() + } +func Login(sessionId int, sshSession ssh.Session) { + events <- async.Async(login, sessionId, sshSession) +} + +func LogOut(sessionId int) { + events <- async.Async(logOut, sessionId) +} + +// Internal interface synchronous + func monitorHoldFile() { watcher, err := fsnotify.NewWatcher() if err != nil { @@ -100,7 +127,6 @@ func monitorHoldFile() { if err != nil { log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err) } - expiryTime := state.expiryTime(holdFilename) for { select { case event, ok := <-watcher.Events: @@ -109,14 +135,7 @@ func monitorHoldFile() { } base := filepath.Base(event.Name) if base == holdFilename { - newExpiryTIme := state.expiryTime(holdFilename) - if newExpiryTIme != expiryTime { - message := fmt.Sprintf("Expiry time of session is now %s\n", - newExpiryTIme.Format(time.DateTime)) - message += holdFileMessage() - messageUsers(message) - expiryTime = newExpiryTIme - } + events <- async.Async(holdFileChange) } case err, ok := <-watcher.Errors: @@ -142,7 +161,7 @@ func holdFileMessage() string { return message } -func Login(sessionId int, sshSession ssh.Session) { +func login(sessionId int, sshSession ssh.Session) { log.Println("New login") hostname, _ := os.Hostname() @@ -160,21 +179,21 @@ func Login(sessionId int, sshSession ssh.Session) { } } - PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) - PrintHelpMessage(sshSession) - agentSession := AgentSession{ startTime: time.Now(), sshSession: sshSession, } state.sessions[sessionId] = &agentSession state.agentUsed = true - LogStatus() + logStatus() + + printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) + holdFileChange() + printHelpMessage(sshSession) } -func PrintHelpMessage(sshSession ssh.Session) { - PrintMessage(sshSession, fmt.Sprintf(helpMessage, - state.expiryTime(holdFilename).Format(time.DateTime), +func printHelpMessage(sshSession ssh.Session) { + printMessage(sshSession, fmt.Sprintf(helpMessage, state.agentExpriryTime)) } @@ -194,21 +213,21 @@ func formatHelpMessage() string { return helpFormatted } -func LogOut(sessionId int) { +func logOut(sessionId int) { log.Println("User logged out") delete(state.sessions, sessionId) - LogStatus() + logStatus() check() } -func PrintMessage(sshSession ssh.Session, message string) { +func printMessage(sshSession ssh.Session, message string) { for _, line := range strings.Split(message, "\n") { io.WriteString(sshSession.Stderr(), "### "+line+"\n\r") } io.WriteString(sshSession.Stderr(), "\n\r") } -func LogStatus() { +func logStatus() { fmt := "%-20s %-20s %-20s" log.Println() log.Printf(fmt, "UID", "START_TIME", "TYPE") @@ -242,6 +261,17 @@ func (state *AgentState) expiryTime(filename string) time.Time { return stats.ModTime().Add(state.agentExpriryTime) } +func holdFileChange() { + newExpiryTIme := state.expiryTime(holdFilename) + if newExpiryTIme != state.lastExpiryTimmeReported { + message := fmt.Sprintf("Expiry time of session is now %s\n", + newExpiryTIme.Format(time.DateTime)) + message += holdFileMessage() + messageUsers(message) + state.lastExpiryTimmeReported = newExpiryTIme + } +} + // Behavior to implement // 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime // 2. The expiry time is relative to the modification time of the .hold file in the @@ -251,7 +281,6 @@ func (state *AgentState) expiryTime(filename string) time.Time { // 4. If the last user logs out, the aagent will exit immediately if no .hold file is // present. Otherwise it will exit after the epxiry time. This allows users to // reconnect later. - func check() { now := time.Now() @@ -268,7 +297,7 @@ func check() { messageUsers( fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) for _, session := range state.sessions { - PrintHelpMessage(session.sshSession) + printHelpMessage(session.sshSession) } } @@ -281,6 +310,6 @@ func check() { func messageUsers(message string) { log.Printf("=== Notification to users: %s", message) for _, session := range state.sessions { - PrintMessage(session.sshSession, message) + printMessage(session.sshSession, message) } } diff --git a/pkg/async/async.go b/pkg/async/async.go new file mode 100644 index 0000000..5c0a1db --- /dev/null +++ b/pkg/async/async.go @@ -0,0 +1,18 @@ +package async + +import "reflect" + +func Async(fn interface{}, args ...interface{}) func() { + fnValue := reflect.ValueOf(fn) + + // Prepare the arguments + params := make([]reflect.Value, len(args)) + for i, arg := range args { + params[i] = reflect.ValueOf(arg) + } + + // Return a function that, when called, will invoke the original function + return func() { + fnValue.Call(params) + } +}