clean solution for concurrence in session.go by serializing all external calls (apart from initialization) through a channel.

This commit is contained in:
Erik Brakkee 2024-07-24 19:23:51 +02:00
parent bdedef12f0
commit 689c8e63b4
3 changed files with 81 additions and 39 deletions

View File

@ -1,10 +1,5 @@
Session is set to expire at %v
The session expires automatically after %v. You can extend expiry of the session using
If there are no more sessions after logging out, the agent
terminates.
You can extend this time using
{{ if eq .os "windows" -}} {{ if eq .os "windows" -}}
echo > %agentdir%\.hold echo > %agentdir%\.hold
@ -13,7 +8,7 @@ You can extend this time using
{{- end }} {{- end }}
The expiry time is equal to the modification time of the .hold The expiry time is equal to the modification time of the .hold
file with the expiry duration added. file with the expiry duration (%v) added.
To prevent the agent from exiting after the last session is gone, To prevent the agent from exiting after the last session is gone,
also use the above command in any shell. also use the above command in any shell.

View File

@ -2,6 +2,7 @@ package agent
import ( import (
"bytes" "bytes"
"converge/pkg/async"
"fmt" "fmt"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
@ -38,6 +39,9 @@ type AgentState struct {
// session expiry time // session expiry time
agentExpriryTime time.Duration agentExpriryTime time.Duration
// Last expiry time reported to the user.
lastExpiryTimmeReported time.Time
// ticker // ticker
tickerInterval time.Duration tickerInterval time.Duration
ticker *time.Ticker ticker *time.Ticker
@ -63,6 +67,11 @@ const holdFilename = ".hold"
var helpMessageTemplate string var helpMessageTemplate string
var helpMessage = formatHelpMessage() var helpMessage = formatHelpMessage()
// Events channel for asynchronous events.
var events = make(chan func(), 10)
// External interface, asynchronous, apart from the initialization.
func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
if fileExists(holdFilename) { if fileExists(holdFilename) {
log.Printf("Removing hold file '%s'", holdFilename) log.Printf("Removing hold file '%s'", holdFilename)
@ -75,6 +84,7 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
startTime: time.Now(), startTime: time.Now(),
advanceWarningTime: advanceWarningTime, advanceWarningTime: advanceWarningTime,
agentExpriryTime: agentExpiryTime, agentExpriryTime: agentExpiryTime,
lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval, tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval), ticker: time.NewTicker(tickerInterval),
sessions: make(map[int]*AgentSession), sessions: make(map[int]*AgentSession),
@ -84,12 +94,29 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
go func() { go func() {
for { for {
<-state.ticker.C <-state.ticker.C
check() events <- async.Async(check)
} }
}() }()
go monitorHoldFile() go monitorHoldFile()
go func() {
for {
event := <-events
event()
}
}()
} }
func Login(sessionId int, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession)
}
func LogOut(sessionId int) {
events <- async.Async(logOut, sessionId)
}
// Internal interface synchronous
func monitorHoldFile() { func monitorHoldFile() {
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
@ -100,7 +127,6 @@ func monitorHoldFile() {
if err != nil { if err != nil {
log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err) log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
} }
expiryTime := state.expiryTime(holdFilename)
for { for {
select { select {
case event, ok := <-watcher.Events: case event, ok := <-watcher.Events:
@ -109,14 +135,7 @@ func monitorHoldFile() {
} }
base := filepath.Base(event.Name) base := filepath.Base(event.Name)
if base == holdFilename { if base == holdFilename {
newExpiryTIme := state.expiryTime(holdFilename) events <- async.Async(holdFileChange)
if newExpiryTIme != expiryTime {
message := fmt.Sprintf("Expiry time of session is now %s\n",
newExpiryTIme.Format(time.DateTime))
message += holdFileMessage()
messageUsers(message)
expiryTime = newExpiryTIme
}
} }
case err, ok := <-watcher.Errors: case err, ok := <-watcher.Errors:
@ -142,7 +161,7 @@ func holdFileMessage() string {
return message return message
} }
func Login(sessionId int, sshSession ssh.Session) { func login(sessionId int, sshSession ssh.Session) {
log.Println("New login") log.Println("New login")
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
@ -160,21 +179,21 @@ func Login(sessionId int, sshSession ssh.Session) {
} }
} }
PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
PrintHelpMessage(sshSession)
agentSession := AgentSession{ agentSession := AgentSession{
startTime: time.Now(), startTime: time.Now(),
sshSession: sshSession, sshSession: sshSession,
} }
state.sessions[sessionId] = &agentSession state.sessions[sessionId] = &agentSession
state.agentUsed = true state.agentUsed = true
LogStatus() logStatus()
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
holdFileChange()
printHelpMessage(sshSession)
} }
func PrintHelpMessage(sshSession ssh.Session) { func printHelpMessage(sshSession ssh.Session) {
PrintMessage(sshSession, fmt.Sprintf(helpMessage, printMessage(sshSession, fmt.Sprintf(helpMessage,
state.expiryTime(holdFilename).Format(time.DateTime),
state.agentExpriryTime)) state.agentExpriryTime))
} }
@ -194,21 +213,21 @@ func formatHelpMessage() string {
return helpFormatted return helpFormatted
} }
func LogOut(sessionId int) { func logOut(sessionId int) {
log.Println("User logged out") log.Println("User logged out")
delete(state.sessions, sessionId) delete(state.sessions, sessionId)
LogStatus() logStatus()
check() check()
} }
func PrintMessage(sshSession ssh.Session, message string) { func printMessage(sshSession ssh.Session, message string) {
for _, line := range strings.Split(message, "\n") { for _, line := range strings.Split(message, "\n") {
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r") io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
} }
io.WriteString(sshSession.Stderr(), "\n\r") io.WriteString(sshSession.Stderr(), "\n\r")
} }
func LogStatus() { func logStatus() {
fmt := "%-20s %-20s %-20s" fmt := "%-20s %-20s %-20s"
log.Println() log.Println()
log.Printf(fmt, "UID", "START_TIME", "TYPE") log.Printf(fmt, "UID", "START_TIME", "TYPE")
@ -242,6 +261,17 @@ func (state *AgentState) expiryTime(filename string) time.Time {
return stats.ModTime().Add(state.agentExpriryTime) return stats.ModTime().Add(state.agentExpriryTime)
} }
func holdFileChange() {
newExpiryTIme := state.expiryTime(holdFilename)
if newExpiryTIme != state.lastExpiryTimmeReported {
message := fmt.Sprintf("Expiry time of session is now %s\n",
newExpiryTIme.Format(time.DateTime))
message += holdFileMessage()
messageUsers(message)
state.lastExpiryTimmeReported = newExpiryTIme
}
}
// Behavior to implement // Behavior to implement
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime // 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
// 2. The expiry time is relative to the modification time of the .hold file in the // 2. The expiry time is relative to the modification time of the .hold file in the
@ -251,7 +281,6 @@ func (state *AgentState) expiryTime(filename string) time.Time {
// 4. If the last user logs out, the aagent will exit immediately if no .hold file is // 4. If the last user logs out, the aagent will exit immediately if no .hold file is
// present. Otherwise it will exit after the epxiry time. This allows users to // present. Otherwise it will exit after the epxiry time. This allows users to
// reconnect later. // reconnect later.
func check() { func check() {
now := time.Now() now := time.Now()
@ -268,7 +297,7 @@ func check() {
messageUsers( messageUsers(
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
for _, session := range state.sessions { for _, session := range state.sessions {
PrintHelpMessage(session.sshSession) printHelpMessage(session.sshSession)
} }
} }
@ -281,6 +310,6 @@ func check() {
func messageUsers(message string) { func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message) log.Printf("=== Notification to users: %s", message)
for _, session := range state.sessions { for _, session := range state.sessions {
PrintMessage(session.sshSession, message) printMessage(session.sshSession, message)
} }
} }

18
pkg/async/async.go Normal file
View File

@ -0,0 +1,18 @@
package async
import "reflect"
func Async(fn interface{}, args ...interface{}) func() {
fnValue := reflect.ValueOf(fn)
// Prepare the arguments
params := make([]reflect.Value, len(args))
for i, arg := range args {
params[i] = reflect.ValueOf(arg)
}
// Return a function that, when called, will invoke the original function
return func() {
fnValue.Call(params)
}
}