clean solution for concurrence in session.go by serializing all external calls (apart from initialization) through a channel.
This commit is contained in:
parent
bdedef12f0
commit
689c8e63b4
@ -1,10 +1,5 @@
|
|||||||
Session is set to expire at %v
|
|
||||||
|
|
||||||
The session expires automatically after %v.
|
You can extend expiry of the session using
|
||||||
If there are no more sessions after logging out, the agent
|
|
||||||
terminates.
|
|
||||||
|
|
||||||
You can extend this time using
|
|
||||||
|
|
||||||
{{ if eq .os "windows" -}}
|
{{ if eq .os "windows" -}}
|
||||||
echo > %agentdir%\.hold
|
echo > %agentdir%\.hold
|
||||||
@ -13,7 +8,7 @@ You can extend this time using
|
|||||||
{{- end }}
|
{{- end }}
|
||||||
|
|
||||||
The expiry time is equal to the modification time of the .hold
|
The expiry time is equal to the modification time of the .hold
|
||||||
file with the expiry duration added.
|
file with the expiry duration (%v) added.
|
||||||
|
|
||||||
To prevent the agent from exiting after the last session is gone,
|
To prevent the agent from exiting after the last session is gone,
|
||||||
also use the above command in any shell.
|
also use the above command in any shell.
|
||||||
|
@ -2,6 +2,7 @@ package agent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"converge/pkg/async"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
@ -38,6 +39,9 @@ type AgentState struct {
|
|||||||
// session expiry time
|
// session expiry time
|
||||||
agentExpriryTime time.Duration
|
agentExpriryTime time.Duration
|
||||||
|
|
||||||
|
// Last expiry time reported to the user.
|
||||||
|
lastExpiryTimmeReported time.Time
|
||||||
|
|
||||||
// ticker
|
// ticker
|
||||||
tickerInterval time.Duration
|
tickerInterval time.Duration
|
||||||
ticker *time.Ticker
|
ticker *time.Ticker
|
||||||
@ -63,6 +67,11 @@ const holdFilename = ".hold"
|
|||||||
var helpMessageTemplate string
|
var helpMessageTemplate string
|
||||||
var helpMessage = formatHelpMessage()
|
var helpMessage = formatHelpMessage()
|
||||||
|
|
||||||
|
// Events channel for asynchronous events.
|
||||||
|
var events = make(chan func(), 10)
|
||||||
|
|
||||||
|
// External interface, asynchronous, apart from the initialization.
|
||||||
|
|
||||||
func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
|
func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
|
||||||
if fileExists(holdFilename) {
|
if fileExists(holdFilename) {
|
||||||
log.Printf("Removing hold file '%s'", holdFilename)
|
log.Printf("Removing hold file '%s'", holdFilename)
|
||||||
@ -75,6 +84,7 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
|
|||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
advanceWarningTime: advanceWarningTime,
|
advanceWarningTime: advanceWarningTime,
|
||||||
agentExpriryTime: agentExpiryTime,
|
agentExpriryTime: agentExpiryTime,
|
||||||
|
lastExpiryTimmeReported: time.Time{},
|
||||||
tickerInterval: tickerInterval,
|
tickerInterval: tickerInterval,
|
||||||
ticker: time.NewTicker(tickerInterval),
|
ticker: time.NewTicker(tickerInterval),
|
||||||
sessions: make(map[int]*AgentSession),
|
sessions: make(map[int]*AgentSession),
|
||||||
@ -84,12 +94,29 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
<-state.ticker.C
|
<-state.ticker.C
|
||||||
check()
|
events <- async.Async(check)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
go monitorHoldFile()
|
go monitorHoldFile()
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
event := <-events
|
||||||
|
event()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Login(sessionId int, sshSession ssh.Session) {
|
||||||
|
events <- async.Async(login, sessionId, sshSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogOut(sessionId int) {
|
||||||
|
events <- async.Async(logOut, sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal interface synchronous
|
||||||
|
|
||||||
func monitorHoldFile() {
|
func monitorHoldFile() {
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -100,7 +127,6 @@ func monitorHoldFile() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
|
log.Printf("Cannot watch old file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
|
||||||
}
|
}
|
||||||
expiryTime := state.expiryTime(holdFilename)
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case event, ok := <-watcher.Events:
|
case event, ok := <-watcher.Events:
|
||||||
@ -109,14 +135,7 @@ func monitorHoldFile() {
|
|||||||
}
|
}
|
||||||
base := filepath.Base(event.Name)
|
base := filepath.Base(event.Name)
|
||||||
if base == holdFilename {
|
if base == holdFilename {
|
||||||
newExpiryTIme := state.expiryTime(holdFilename)
|
events <- async.Async(holdFileChange)
|
||||||
if newExpiryTIme != expiryTime {
|
|
||||||
message := fmt.Sprintf("Expiry time of session is now %s\n",
|
|
||||||
newExpiryTIme.Format(time.DateTime))
|
|
||||||
message += holdFileMessage()
|
|
||||||
messageUsers(message)
|
|
||||||
expiryTime = newExpiryTIme
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case err, ok := <-watcher.Errors:
|
case err, ok := <-watcher.Errors:
|
||||||
@ -142,7 +161,7 @@ func holdFileMessage() string {
|
|||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
func Login(sessionId int, sshSession ssh.Session) {
|
func login(sessionId int, sshSession ssh.Session) {
|
||||||
log.Println("New login")
|
log.Println("New login")
|
||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
|
|
||||||
@ -160,21 +179,21 @@ func Login(sessionId int, sshSession ssh.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
|
|
||||||
PrintHelpMessage(sshSession)
|
|
||||||
|
|
||||||
agentSession := AgentSession{
|
agentSession := AgentSession{
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
sshSession: sshSession,
|
sshSession: sshSession,
|
||||||
}
|
}
|
||||||
state.sessions[sessionId] = &agentSession
|
state.sessions[sessionId] = &agentSession
|
||||||
state.agentUsed = true
|
state.agentUsed = true
|
||||||
LogStatus()
|
logStatus()
|
||||||
|
|
||||||
|
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
|
||||||
|
holdFileChange()
|
||||||
|
printHelpMessage(sshSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrintHelpMessage(sshSession ssh.Session) {
|
func printHelpMessage(sshSession ssh.Session) {
|
||||||
PrintMessage(sshSession, fmt.Sprintf(helpMessage,
|
printMessage(sshSession, fmt.Sprintf(helpMessage,
|
||||||
state.expiryTime(holdFilename).Format(time.DateTime),
|
|
||||||
state.agentExpriryTime))
|
state.agentExpriryTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,21 +213,21 @@ func formatHelpMessage() string {
|
|||||||
return helpFormatted
|
return helpFormatted
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogOut(sessionId int) {
|
func logOut(sessionId int) {
|
||||||
log.Println("User logged out")
|
log.Println("User logged out")
|
||||||
delete(state.sessions, sessionId)
|
delete(state.sessions, sessionId)
|
||||||
LogStatus()
|
logStatus()
|
||||||
check()
|
check()
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrintMessage(sshSession ssh.Session, message string) {
|
func printMessage(sshSession ssh.Session, message string) {
|
||||||
for _, line := range strings.Split(message, "\n") {
|
for _, line := range strings.Split(message, "\n") {
|
||||||
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
|
io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
|
||||||
}
|
}
|
||||||
io.WriteString(sshSession.Stderr(), "\n\r")
|
io.WriteString(sshSession.Stderr(), "\n\r")
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogStatus() {
|
func logStatus() {
|
||||||
fmt := "%-20s %-20s %-20s"
|
fmt := "%-20s %-20s %-20s"
|
||||||
log.Println()
|
log.Println()
|
||||||
log.Printf(fmt, "UID", "START_TIME", "TYPE")
|
log.Printf(fmt, "UID", "START_TIME", "TYPE")
|
||||||
@ -242,6 +261,17 @@ func (state *AgentState) expiryTime(filename string) time.Time {
|
|||||||
return stats.ModTime().Add(state.agentExpriryTime)
|
return stats.ModTime().Add(state.agentExpriryTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func holdFileChange() {
|
||||||
|
newExpiryTIme := state.expiryTime(holdFilename)
|
||||||
|
if newExpiryTIme != state.lastExpiryTimmeReported {
|
||||||
|
message := fmt.Sprintf("Expiry time of session is now %s\n",
|
||||||
|
newExpiryTIme.Format(time.DateTime))
|
||||||
|
message += holdFileMessage()
|
||||||
|
messageUsers(message)
|
||||||
|
state.lastExpiryTimmeReported = newExpiryTIme
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Behavior to implement
|
// Behavior to implement
|
||||||
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
|
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
|
||||||
// 2. The expiry time is relative to the modification time of the .hold file in the
|
// 2. The expiry time is relative to the modification time of the .hold file in the
|
||||||
@ -251,7 +281,6 @@ func (state *AgentState) expiryTime(filename string) time.Time {
|
|||||||
// 4. If the last user logs out, the aagent will exit immediately if no .hold file is
|
// 4. If the last user logs out, the aagent will exit immediately if no .hold file is
|
||||||
// present. Otherwise it will exit after the epxiry time. This allows users to
|
// present. Otherwise it will exit after the epxiry time. This allows users to
|
||||||
// reconnect later.
|
// reconnect later.
|
||||||
|
|
||||||
func check() {
|
func check() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
@ -268,7 +297,7 @@ func check() {
|
|||||||
messageUsers(
|
messageUsers(
|
||||||
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
|
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
|
||||||
for _, session := range state.sessions {
|
for _, session := range state.sessions {
|
||||||
PrintHelpMessage(session.sshSession)
|
printHelpMessage(session.sshSession)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -281,6 +310,6 @@ func check() {
|
|||||||
func messageUsers(message string) {
|
func messageUsers(message string) {
|
||||||
log.Printf("=== Notification to users: %s", message)
|
log.Printf("=== Notification to users: %s", message)
|
||||||
for _, session := range state.sessions {
|
for _, session := range state.sessions {
|
||||||
PrintMessage(session.sshSession, message)
|
printMessage(session.sshSession, message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
18
pkg/async/async.go
Normal file
18
pkg/async/async.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package async
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
func Async(fn interface{}, args ...interface{}) func() {
|
||||||
|
fnValue := reflect.ValueOf(fn)
|
||||||
|
|
||||||
|
// Prepare the arguments
|
||||||
|
params := make([]reflect.Value, len(args))
|
||||||
|
for i, arg := range args {
|
||||||
|
params[i] = reflect.ValueOf(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a function that, when called, will invoke the original function
|
||||||
|
return func() {
|
||||||
|
fnValue.Call(params)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user