converge/pkg/agent/session/session.go
2024-08-13 21:33:29 +02:00

380 lines
9.5 KiB
Go

package session
import (
"bytes"
"fmt"
"git.wamblee.org/converge/pkg/comms"
"github.com/fsnotify/fsnotify"
"github.com/gliderlabs/ssh"
"io"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"text/template"
"time"
_ "embed"
)
// TODO: fix concurrency.
// All methods put a message on a channel.
//
// Using a channel of functions will work.
// With the default (unbuffered) channel, every send blocks until the
// receiver takes it, thereby effectively serializing everything.
//
// AgentState is the global configuration and runtime bookkeeping of the agent.
type AgentState struct {
	// Communication channel to the converge server; its SideChannel is used
	// for session info and expiry-time updates.
	commChannel comms.CommChannel
	// Time at which the agent state was created.
	startTime time.Time

	// Advance warning time: how long before the expiry time users are
	// notified that something important is about to happen.
	advanceWarningDuration time.Duration
	// Session expiry duration, added to a reference time to obtain the
	// expiry time.
	agentExpiryDuration time.Duration
	// Last expiry time that was reported to the users.
	lastExpiryTimeReported time.Time
	// True while the most recent check found the expiry time to be closer
	// than advanceWarningDuration.
	expiryIsNear bool

	// Interval of the periodic check and the ticker that drives it.
	tickerInterval time.Duration
	ticker         *time.Ticker

	// Map of unique session (client) id to its session.
	clients map[string]*AgentSession

	// Time of the most recent login or detected user activity.
	lastUserActivityTime time.Time
	// Set once a user has logged in.
	agentUsed bool
}
// AgentSession describes one connected client.
type AgentSession struct {
	// Time at which the client logged in.
	startTime time.Time

	// SSH session of the client; messages to the user are written to it.
	sshSession ssh.Session
}
// state is the single package-level agent state. It is assigned by
// ConfigureAgent and afterwards read and modified by the functions in this
// file.
var state AgentState

// holdFilename is the name of the hold file, looked up relative to the
// current working directory.
const holdFilename = ".hold"

// helpMessageTemplate holds the contents of help.txt, embedded at build time.
// It is parsed as a text/template by formatHelpMessage.
//
//go:embed help.txt
var helpMessageTemplate string

// helpMessage is the help text rendered once during package initialization.
// It is later used as a format string by printHelpMessage.
var helpMessage = formatHelpMessage()

// Events channel for asynchronous events. Each element is a function that the
// event loop started by ConfigureAgent receives and calls. The channel is
// unbuffered, so a sender blocks until the loop receives the function.
var events = make(chan func())
// External interface, asynchronous, apart from the initialization.

// ConfigureAgent initializes the package-level agent state and starts the
// background goroutines of the agent. It must be called before any of the
// other exported functions, because those only enqueue work for the event
// loop that is started here.
//
// Parameters:
//   - commChannel: channel to the converge server; its SideChannel receives
//     expiry-time updates and session info.
//   - advanceWarningTime: how long before expiry users start to be warned.
//   - agentExpiryTime: duration after which the agent expires.
//   - tickerInterval: interval of the periodic expiry check. It is passed to
//     time.NewTicker, which panics for non-positive values.
//
// A hold file left behind in the working directory is removed first. The
// goroutines started here run for the remaining lifetime of the process.
func ConfigureAgent(commChannel comms.CommChannel,
	advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
	if fileExists(holdFilename) {
		log.Printf("Removing hold file '%s'", holdFilename)
		if err := os.Remove(holdFilename); err != nil {
			// Best effort: the agent still starts, but report why removal failed.
			log.Printf("Could not remove hold file: '%s': %v", holdFilename, err)
		}
	}

	state = AgentState{
		commChannel:            commChannel,
		startTime:              time.Now(),
		advanceWarningDuration: advanceWarningTime,
		agentExpiryDuration:    agentExpiryTime,
		lastExpiryTimeReported: time.Time{},
		tickerInterval:         tickerInterval,
		ticker:                 time.NewTicker(tickerInterval),
		clients:                make(map[string]*AgentSession),
		lastUserActivityTime:   time.Time{},
		agentUsed:              false,
	}

	log.Printf("Agent expiry duration is %v", state.agentExpiryDuration)

	// Report the initial expiry time through the same helper that is used
	// for all later updates.
	if err := sendExpiryTimeUpdateEvent(); err != nil {
		log.Printf("Could not send initial expiry time to converge server: %v", err)
	}

	// Periodic check: every tick enqueues check on the event loop.
	go func() {
		for {
			<-state.ticker.C
			events <- check
		}
	}()

	go monitorHoldFile()

	// Event loop: runs the queued functions one at a time, so that all access
	// to state made through the events channel is serialized.
	go func() {
		for {
			event := <-events
			event()
		}
	}()
}
// Login enqueues the registration of a new client session on the events
// channel. The call returns once the request has been handed over, not when
// the login has been processed.
func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
	handleLogin := func() { login(sessionInfo, sshSession) }
	events <- handleLogin
}
func LogOut(clientId string) {
events <- func() {
logOut(clientId)
}
}
// UserActivityDetected enqueues the processing of a user-activity
// notification on the events channel.
func UserActivityDetected() {
	// userActivityDetected already has the func() shape the channel expects.
	events <- userActivityDetected
}
func MessageUsers(message string) {
events <- func() {
messageUsers(message)
}
}
// Internal interface synchronous

// userActivityDetected records the current time as the last user activity.
// When the expiry was flagged as near, it additionally tells the users that
// the session was extended, sends an expiry-time update and clears the flag.
func userActivityDetected() {
	state.lastUserActivityTime = time.Now()
	if !state.expiryIsNear {
		return
	}

	messageUsers("User activity detected, session extended.")
	// The result of the update is not checked here.
	sendExpiryTimeUpdateEvent()
	state.expiryIsNear = false
}
// monitorHoldFile watches the current working directory and enqueues
// holdFileChange on the events channel for every file system event whose base
// name equals holdFilename.
//
// It blocks until one of the watcher channels is closed and is therefore
// meant to run in its own goroutine. If the watcher cannot be created or the
// directory cannot be added, the problem is logged and the function returns:
// the agent keeps running, only without notifications about changes of the
// hold file.
func monitorHoldFile() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
		// Without a watcher there is nothing to close or to listen to.
		return
	}
	defer watcher.Close()

	if err = watcher.Add("."); err != nil {
		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
		// Nothing is being watched, so waiting for events would block forever.
		return
	}

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			// The whole directory is watched; only react to the hold file.
			if filepath.Base(event.Name) == holdFilename {
				events <- holdFileChange
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			log.Println("Error:", err)
		}
	}
}
// holdFileMessage returns the text that explains to users what happens when
// the last user exits. The text depends on whether the hold file currently
// exists.
func holdFileMessage() string {
	if fileExists(holdFilename) {
		return fmt.Sprintf("When the last user exits, the session will timeout and not exit immediately.\n"+
			"Remove the %s file if you want the session to terminate when the last user logs out",
			holdFilename)
	}
	// No formatting verbs here, so a plain string is sufficient.
	return "When the last user exits, the agent will return and the continuous\n" +
		"integration job will continue."
}
// login registers a new client session. It must run on the event loop because
// it modifies state without further synchronization.
//
// Steps:
//   - touch the hold file if it exists, unless its modification time lies in
//     the future;
//   - store the session under sessionInfo.ClientId, record the login as user
//     activity and mark the agent as used;
//   - send sessionInfo to the converge server over the side channel;
//   - log the client table and greet the user with the host name, the expiry
//     information and the help message.
func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
	log.Println("New login")

	// The host name is only used in the greeting; when the lookup fails the
	// greeting is printed with an empty host name.
	hostname, _ := os.Hostname()

	holdFileStats, ok := fileExistsWithStats(holdFilename)
	// Guard against missing stats: a failed stat may leave holdFileStats nil,
	// and calling ModTime on it would panic.
	if ok && holdFileStats != nil {
		now := time.Now()
		if holdFileStats.ModTime().After(now) {
			// modification time in the future, leaving intact
			log.Println("Hold file has modification time in the future, leaving intact")
		} else {
			log.Printf("Touching hold file '%s'", holdFilename)
			// Use one timestamp for both access and modification time.
			if err := os.Chtimes(holdFilename, now, now); err != nil {
				log.Printf("Could not touch hold file: '%s': %v", holdFilename, err)
			}
		}
	}

	agentSession := AgentSession{
		startTime:  time.Now(),
		sshSession: sshSession,
	}
	state.clients[sessionInfo.ClientId] = &agentSession
	state.lastUserActivityTime = time.Now()
	state.agentUsed = true

	err := comms.SendWithTimeout(state.commChannel.SideChannel,
		comms.ConvergeMessage{Value: sessionInfo})
	if err != nil {
		log.Printf("Could not send session info to converge server, information on server may be incomplete %v", err)
	}

	logStatus()
	printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
	// Reports the expiry time to the users if it differs from the last one
	// that was reported.
	holdFileChange()
	printHelpMessage(sshSession)
}
// printHelpMessage writes the help text to the given session. helpMessage is
// used as a format string with the agent expiry duration as its argument.
func printHelpMessage(sshSession ssh.Session) {
	text := fmt.Sprintf(helpMessage, state.agentExpiryDuration)
	printMessage(sshSession, text)
}
// formatHelpMessage renders helpMessageTemplate as a text/template, passing
// the operating system name under the key "os", and returns the result. It
// panics when the template cannot be parsed or executed.
func formatHelpMessage() string {
	parsed, err := template.New("help").Parse(helpMessageTemplate)
	if err != nil {
		panic(err)
	}

	log.Println("Running on ", runtime.GOOS)

	var rendered bytes.Buffer
	values := map[string]string{"os": runtime.GOOS}
	if err := parsed.Execute(&rendered, values); err != nil {
		panic(err)
	}
	return rendered.String()
}
// logOut removes the client with the given id from the client map, logs the
// remaining clients and then runs check.
func logOut(clientId string) {
	log.Println("User logged out")

	delete(state.clients, clientId)
	logStatus()

	// check may terminate the process, e.g. when no clients are left.
	check()
}
// printMessage writes message to the stderr stream of the session. Every line
// of the message is prefixed with "### " and terminated with "\n\r"; one
// extra empty line is written at the end. Write errors are not checked.
func printMessage(sshSession ssh.Session, message string) {
	const (
		prefix     = "### "
		lineEnding = "\n\r"
	)
	lines := strings.Split(message, "\n")
	for i := range lines {
		io.WriteString(sshSession.Stderr(), prefix+lines[i]+lineEnding)
	}
	io.WriteString(sshSession.Stderr(), lineEnding)
}
// logStatus logs a table of the connected clients with their id, start time
// and session type. An empty SSH subsystem is shown as "ssh".
func logStatus() {
	const rowFormat = "%-20s %-20s %-20s"

	log.Println()
	log.Printf(rowFormat, "CLIENT", "START_TIME", "TYPE")
	for clientID, client := range state.clients {
		kind := client.sshSession.Subsystem()
		if kind == "" {
			kind = "ssh"
		}
		started := client.startTime.Format(time.DateTime)
		log.Printf(rowFormat, clientID, started, kind)
	}
	log.Println()
}
// fileExistsWithStats stats filename and returns its file info together with
// true when the stat call succeeds. On any stat error, including errors other
// than "does not exist", it returns (nil, false), so a true result always
// comes with usable file info.
func fileExistsWithStats(filename string) (os.FileInfo, bool) {
	stats, err := os.Stat(filename)
	if err != nil {
		return nil, false
	}
	return stats, true
}
// fileExists reports whether filename exists. Only a "does not exist" error
// from the stat call yields false; any other outcome, including other stat
// errors, yields true.
func fileExists(filename string) bool {
	if _, err := os.Stat(filename); os.IsNotExist(err) {
		return false
	}
	return true
}
// expiryTime computes the moment at which the agent expires.
//
// As long as the agent has not been used, this is the start time plus the
// expiry duration. Afterwards it is the later of two candidates, each being
// a reference time plus the expiry duration: the modification time of
// filename (only if it can be stat'ed) and the last user activity time.
func (s *AgentState) expiryTime(filename string) time.Time {
	if !s.agentUsed {
		return s.startTime.Add(s.agentExpiryDuration)
	}

	// Stays at the zero time when the file cannot be stat'ed.
	var fromHoldFile time.Time
	if info, err := os.Stat(filename); err == nil {
		fromHoldFile = info.ModTime().Add(s.agentExpiryDuration)
	}

	fromActivity := s.lastUserActivityTime.Add(s.agentExpiryDuration)
	if fromActivity.After(fromHoldFile) {
		return fromActivity
	}
	return fromHoldFile
}
// holdFileChange recomputes the expiry time and, when it differs from the
// last one reported, tells the users about the new expiry time (together with
// the hold file explanation), remembers it and sends an expiry-time update
// over the side channel.
func holdFileChange() {
	newExpiryTime := state.expiryTime(holdFilename)
	// Equal compares the time instant only; == would also compare the
	// location and the monotonic clock reading.
	if newExpiryTime.Equal(state.lastExpiryTimeReported) {
		return
	}

	message := fmt.Sprintf("Expiry time of session is now %s\n",
		newExpiryTime.Format(time.DateTime))
	message += holdFileMessage()
	messageUsers(message)

	state.lastExpiryTimeReported = newExpiryTime
	if err := sendExpiryTimeUpdateEvent(); err != nil {
		log.Printf("Could not send expiry time update to converge server: %v", err)
	}
}
// sendExpiryTimeUpdateEvent sends the current expiry time over the side
// channel and returns the result of the send.
func sendExpiryTimeUpdateEvent() error {
	update := comms.ConvergeMessage{
		Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
	}
	return comms.Send(state.commChannel.SideChannel, update)
}
// check enforces the lifetime rules of the agent. It is run for every ticker
// event and after every logout.
//
// Behavior to implement:
//  1. There is a global timeout for all agent clients together:
//     state.agentExpiryDuration.
//  2. Until the first login, the expiry time is relative to the start time of
//     the agent. After that it is relative to the later of the modification
//     time of the hold file in the agent directory (if that file exists) and
//     the last user activity.
//  3. If we are close to the expiry time, then we message users with
//     instructions on how to prevent the timeout.
//  4. If the last user logs out, the agent will exit immediately if no hold
//     file is present. Otherwise it will exit after the expiry time. This
//     allows users to reconnect later.
//
// When the agent does not exit, the current expiry time is sent over the side
// channel at the end of every check.
func check() {
	now := time.Now()
	expiryTime := state.expiryTime(holdFilename)

	if now.After(expiryTime) {
		messageUsers("Expiry time was reached logging out")
		// Wait before exiting, presumably so the message can still reach users.
		time.Sleep(5 * time.Second)
		log.Println("Agent exiting")
		os.Exit(0)
	}

	state.expiryIsNear = expiryTime.Sub(now) < state.advanceWarningDuration
	if state.expiryIsNear {
		messageUsers(
			fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime)))
		//for _, session := range state.clients {
		//	printHelpMessage(session.sshSession)
		//}
	}

	if state.agentUsed && !fileExists(holdFilename) && len(state.clients) == 0 {
		log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename)
		os.Exit(0)
	}

	if err := sendExpiryTimeUpdateEvent(); err != nil {
		log.Printf("Could not send expiry time update to converge server: %v", err)
	}
}
// messageUsers logs the message and prints it to the session of every
// connected client.
func messageUsers(message string) {
	log.Printf("=== Notification to users: %s", message)
	for clientID := range state.clients {
		printMessage(state.clients[clientID].sshSession, message)
	}
}