clean solution for concurrence in session.go by serializing all external calls (apart from initialization) through a channel.

This commit is contained in:
Erik Brakkee 2024-07-24 19:23:51 +02:00
parent 7351fdaf9c
commit ddc3b24ebf
3 changed files with 81 additions and 39 deletions

View File

@ -1,10 +1,5 @@
Session is set to expire at %v
The session expires automatically after %v.
If there are no more sessions after logging out, the agent
terminates.
You can extend this time using
You can extend expiry of the session using
{{ if eq .os "windows" -}}
echo > %agentdir%\.hold
@ -13,7 +8,7 @@ You can extend this time using
{{- end }}
The expiry time is equal to the modification time of the .hold
file with the expiry duration added.
file with the expiry duration (%v) added.
To prevent the agent from exiting after the last session is gone,
also use the above command in any shell.

View File

@ -2,6 +2,7 @@ package agent
import (
"bytes"
"converge/pkg/async"
"fmt"
"github.com/fsnotify/fsnotify"
"github.com/gliderlabs/ssh"
@ -38,6 +39,9 @@ type AgentState struct {
// session expiry time
agentExpriryTime time.Duration
// Last expiry time reported to the user.
lastExpiryTimmeReported time.Time
// ticker
tickerInterval time.Duration
ticker *time.Ticker
@ -63,6 +67,11 @@ const holdFilename = ".hold"
var helpMessageTemplate string
var helpMessage = formatHelpMessage()
// Events channel for asynchronous events.
var events = make(chan func(), 10)
// External interface, asynchronous, apart from the initialization.
func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
if fileExists(holdFilename) {
log.Printf("Removing hold file '%s'", holdFilename)
@ -75,6 +84,7 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
startTime: time.Now(),
advanceWarningTime: advanceWarningTime,
agentExpriryTime: agentExpiryTime,
lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval),
sessions: make(map[int]*AgentSession),
@ -84,11 +94,28 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
go func() {
for {
<-state.ticker.C
check()
events <- async.Async(check)
}
}()
go monitorHoldFile()
go func() {
for {
event := <-events
event()
}
}()
}
func Login(sessionId int, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession)
}
func LogOut(sessionId int) {
events <- async.Async(logOut, sessionId)
}
// Internal interface synchronous
func monitorHoldFile() {
watcher, err := fsnotify.NewWatcher()
@ -100,7 +127,6 @@ func monitorHoldFile() {
if err != nil {
log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
}
expiryTime := state.expiryTime(holdFilename)
for {
select {
case event, ok := <-watcher.Events:
@ -109,14 +135,7 @@ func monitorHoldFile() {
}
base := filepath.Base(event.Name)
if base == holdFilename {
newExpiryTIme := state.expiryTime(holdFilename)
if newExpiryTIme != expiryTime {
message := fmt.Sprintf("Expiry time of session is now %s\n",
newExpiryTIme.Format(time.DateTime))
message += holdFileMessage()
messageUsers(message)
expiryTime = newExpiryTIme
}
events <- async.Async(holdFileChange)
}
case err, ok := <-watcher.Errors:
@ -142,7 +161,7 @@ func holdFileMessage() string {
return message
}
func Login(sessionId int, sshSession ssh.Session) {
func login(sessionId int, sshSession ssh.Session) {
log.Println("New login")
hostname, _ := os.Hostname()
@ -160,21 +179,21 @@ func Login(sessionId int, sshSession ssh.Session) {
}
}
PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
PrintHelpMessage(sshSession)
agentSession := AgentSession{
startTime: time.Now(),
sshSession: sshSession,
}
state.sessions[sessionId] = &agentSession
state.agentUsed = true
LogStatus()
logStatus()
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
holdFileChange()
printHelpMessage(sshSession)
}
func PrintHelpMessage(sshSession ssh.Session) {
PrintMessage(sshSession, fmt.Sprintf(helpMessage,
state.expiryTime(holdFilename).Format(time.DateTime),
func printHelpMessage(sshSession ssh.Session) {
printMessage(sshSession, fmt.Sprintf(helpMessage,
state.agentExpriryTime))
}
@ -194,21 +213,21 @@ func formatHelpMessage() string {
return helpFormatted
}
func LogOut(sessionId int) {
func logOut(sessionId int) {
log.Println("User logged out")
delete(state.sessions, sessionId)
LogStatus()
logStatus()
check()
}
func PrintMessage(sshSession ssh.Session, message string) {
func printMessage(sshSession ssh.Session, message string) {
for _, line := range strings.Split(message, "\n") {
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
}
io.WriteString(sshSession.Stderr(), "\n\r")
}
func LogStatus() {
func logStatus() {
fmt := "%-20s %-20s %-20s"
log.Println()
log.Printf(fmt, "UID", "START_TIME", "TYPE")
@ -242,6 +261,17 @@ func (state *AgentState) expiryTime(filename string) time.Time {
return stats.ModTime().Add(state.agentExpriryTime)
}
func holdFileChange() {
newExpiryTIme := state.expiryTime(holdFilename)
if newExpiryTIme != state.lastExpiryTimmeReported {
message := fmt.Sprintf("Expiry time of session is now %s\n",
newExpiryTIme.Format(time.DateTime))
message += holdFileMessage()
messageUsers(message)
state.lastExpiryTimmeReported = newExpiryTIme
}
}
// Behavior to implement
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
// 2. The expiry time is relative to the modification time of the .hold file in the
@ -251,7 +281,6 @@ func (state *AgentState) expiryTime(filename string) time.Time {
// 4. If the last user logs out, the aagent will exit immediately if no .hold file is
// present. Otherwise it will exit after the epxiry time. This allows users to
// reconnect later.
func check() {
now := time.Now()
@ -268,7 +297,7 @@ func check() {
messageUsers(
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
for _, session := range state.sessions {
PrintHelpMessage(session.sshSession)
printHelpMessage(session.sshSession)
}
}
@ -281,6 +310,6 @@ func check() {
func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message)
for _, session := range state.sessions {
PrintMessage(session.sshSession, message)
printMessage(session.sshSession, message)
}
}

18
pkg/async/async.go Normal file
View File

@ -0,0 +1,18 @@
package async
import "reflect"
func Async(fn interface{}, args ...interface{}) func() {
fnValue := reflect.ValueOf(fn)
// Prepare the arguments
params := make([]reflect.Value, len(args))
for i, arg := range args {
params[i] = reflect.ValueOf(arg)
}
// Return a function that, when called, will invoke the original function
return func() {
fnValue.Call(params)
}
}