converge/pkg/agent/session.go
Erik Brakkee d3cbf8388f Lots of refactoring.
Now hijacking the ssh connection setup in the listener to exchange some information before passing the connection on to the SSH server.

Next step is to do the full exchange of required information. To make this easy, some simple Read and Write methods with timeouts that use gob are needed.
2024-09-08 11:16:49 +02:00

345 lines
8.7 KiB
Go

package agent
import (
	"bytes"
	_ "embed"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"text/template"
	"time"

	"converge/pkg/async"
	"converge/pkg/comms"

	"github.com/fsnotify/fsnotify"
	"github.com/gliderlabs/ssh"
)
// TODO: fix concurrency.
// All methods put a message on a channel.
//
// Using a channel of functions will work.
// When an unbuffered channel (the default) is used, sends always block,
// thereby effectively serializing everything.
//
// make(chan func())
// AgentState holds the global configuration and runtime state of the agent.
// It is stored in the package-level variable state and is not guarded by a
// mutex; the functions that mutate it after initialization (login, logOut,
// holdFileChange, check) are dispatched through the events channel.
type AgentState struct {
	// Connection to the controller; its SideChannel is used to send agent
	// info, session info and expiry-time updates.
	commChannel comms.CommChannel
	// Time the agent was configured; the expiry base until the first login.
	startTime time.Time
	// How long before the expiry time users start receiving warnings
	// (compared against the remaining time in check).
	advanceWarningTime time.Duration
	// Agent lifetime, added to the relevant base time (agent start, hold file
	// modification time, or last login) to compute the expiry time.
	// NOTE: field name is misspelled (sic).
	agentExpriryTime time.Duration
	// Last expiry time reported to the users; used to suppress repeated
	// notifications when the expiry time has not changed.
	// NOTE: field name is misspelled (sic).
	lastExpiryTimmeReported time.Time
	// Interval at which check is scheduled, and the ticker driving it.
	tickerInterval time.Duration
	ticker *time.Ticker
	// Active sessions keyed by the session id passed to Login/LogOut.
	sessions map[int]*AgentSession
	// Time of the most recent login; an expiry base once the agent is used.
	lastUserLoginTime time.Time
	// Set on the first login. Switches the expiry computation away from the
	// start time and allows the agent to exit when the last user leaves
	// without a hold file present.
	agentUsed bool
}
// AgentSession tracks a single logged-in user session.
type AgentSession struct {
	// Time at which the session logged in.
	startTime time.Time
	// For sending messages to the user (written to the session's stderr)
	// and for reading the session type (its subsystem).
	sshSession ssh.Session
}
// state is the single, package-wide agent state, initialized by ConfigureAgent.
var state AgentState

// holdFilename is the name of the hold file, resolved relative to the agent's
// working directory. Its presence keeps the agent alive after the last user
// logs out, and its modification time is used as a base for the expiry time.
const holdFilename = ".hold"

// helpMessageTemplate is the text/template source of the help message,
// embedded from help.txt at build time.
//go:embed help.txt
var helpMessageTemplate string

// helpMessage is the rendered help text, computed once at package
// initialization.
var helpMessage = formatHelpMessage()

// Events channel for asynchronous events. Each event is a function that is
// run by the single consumer goroutine started in ConfigureAgent, so event
// functions execute one at a time. Up to 10 events can be queued before
// senders block.
var events = make(chan func(), 10)
// External interface, asynchronous, apart from the initialization.
// ConfigureAgent initializes the global agent state and starts the background
// goroutines that drive it. Login and LogOut only queue events; those events
// are executed only once ConfigureAgent has started the event loop.
//
// A hold file left over from a previous run is removed first, so the agent
// starts without a hold. The controller is then sent the agent info and the
// initial expiry time over the side channel.
//
// Three goroutines are started:
//   - one that queues a check event on every tick of the ticker,
//   - one that watches the hold file (monitorHoldFile),
//   - the event loop, which runs queued events one at a time.
func ConfigureAgent(commChannel comms.CommChannel,
	advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
	if fileExists(holdFilename) {
		log.Printf("Removing hold file '%s'", holdFilename)
		if err := os.Remove(holdFilename); err != nil {
			log.Printf("Could not remove hold file: '%s': %v", holdFilename, err)
		}
	}

	state = AgentState{
		commChannel:             commChannel,
		startTime:               time.Now(),
		advanceWarningTime:      advanceWarningTime,
		agentExpriryTime:        agentExpiryTime,
		lastExpiryTimmeReported: time.Time{},
		tickerInterval:          tickerInterval,
		ticker:                  time.NewTicker(tickerInterval),
		sessions:                make(map[int]*AgentSession),
		lastUserLoginTime:       time.Time{},
		agentUsed:               false,
	}

	// Compute once so the logged value and the value sent to the controller agree.
	expiry := state.expiryTime(holdFilename)
	log.Printf("Agent expires at %s", expiry.Format(time.DateTime))
	state.commChannel.SideChannel.Send(comms.NewAgentInfo())
	state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(expiry))

	// Periodic checks are queued as events rather than run here, so they never
	// execute concurrently with other event functions.
	go func() {
		for range state.ticker.C {
			events <- async.Async(check)
		}
	}()
	go monitorHoldFile()
	go func() {
		for event := range events {
			event()
		}
	}()
}
// Login registers a new user session. The actual work (login) is queued on
// the events channel and executed asynchronously by the event loop.
func Login(sessionId int, sshSession ssh.Session) {
	loginEvent := async.Async(login, sessionId, sshSession)
	events <- loginEvent
}
func LogOut(sessionId int) {
events <- async.Async(logOut, sessionId)
}
// Internal interface synchronous
// monitorHoldFile watches the current directory and queues a holdFileChange
// event whenever any filesystem event is reported for the hold file, so users
// are notified of changes to the expiry time. It returns when the watcher's
// channels are closed.
//
// If the watcher cannot be set up, the problem is logged and the function
// returns; the agent keeps working without these notifications.
func monitorHoldFile() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		// watcher is unusable here; continuing would call methods on it.
		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
		return
	}
	defer watcher.Close()

	// Watch the directory rather than the file: the hold file may not exist
	// (ConfigureAgent removes it at startup), and events are filtered by name below.
	if err = watcher.Add("."); err != nil {
		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
		return
	}

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			if filepath.Base(event.Name) == holdFilename {
				events <- async.Async(holdFileChange)
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			log.Println("Error:", err)
		}
	}
}
// holdFileMessage returns the user-facing explanation of what happens when
// the last user logs out. The text depends on whether the hold file is
// currently present.
func holdFileMessage() string {
	if fileExists(holdFilename) {
		return fmt.Sprintf("When the last user exits, the session will timeout and not exit immediately.\n"+
			"Remove the %s file if you want the session to terminate when the last user logs out",
			holdFilename)
	}
	return "When the last user exits, the agent will return and the continuous\n" +
		"integration job will continue."
}
// login registers a new session. It runs on the event loop and is queued by
// Login.
//
// It sends the session type to the controller and refreshes the hold file's
// modification time (unless that time lies in the future), which extends the
// expiry time. It then records the session and greets the user with the
// host name, the current expiry information and the help text.
func login(sessionId int, sshSession ssh.Session) {
	log.Println("New login")
	hostname, err := os.Hostname()
	if err != nil {
		log.Printf("Could not determine hostname: %v", err)
	}
	sessionType := sshSession.Subsystem()
	if sessionType == "" {
		sessionType = "ssh"
	}
	state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType))

	now := time.Now()
	// Guard against a nil FileInfo as well, so a failed Stat can never panic here.
	if holdFileStats, ok := fileExistsWithStats(holdFilename); ok && holdFileStats != nil {
		if holdFileStats.ModTime().After(now) {
			// A future modification time pushes the expiry time out (presumably
			// set on purpose); touching the file would shorten the expiry.
			log.Println("Hold file has modification time in the future, leaving intact")
		} else {
			log.Printf("Touching hold file '%s'", holdFilename)
			if err := os.Chtimes(holdFilename, now, now); err != nil {
				log.Printf("Could not touch hold file: '%s': %v", holdFilename, err)
			}
		}
	}

	state.sessions[sessionId] = &AgentSession{
		startTime:  now,
		sshSession: sshSession,
	}
	state.lastUserLoginTime = now
	state.agentUsed = true
	logStatus()
	printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
	holdFileChange()
	printHelpMessage(sshSession)
}
// printHelpMessage sends the help text to a single session. The rendered help
// text is used as a printf-style format whose only argument is the agent
// expiry duration.
// NOTE(review): this assumes help.txt contains exactly one formatting verb and
// no other '%' characters — confirm.
func printHelpMessage(sshSession ssh.Session) {
	text := fmt.Sprintf(helpMessage, state.agentExpriryTime)
	printMessage(sshSession, text)
}
func formatHelpMessage() string {
templ, err := template.New("help").Parse(helpMessageTemplate)
if err != nil {
panic(err)
}
helpFormattedBuf := bytes.NewBuffer(make([]byte, 0))
log.Println("Running on ", runtime.GOOS)
data := map[string]string{"os": runtime.GOOS}
err = templ.Execute(helpFormattedBuf, data)
if err != nil {
panic(err)
}
helpFormatted := helpFormattedBuf.String()
return helpFormatted
}
// logOut removes a session. It runs on the event loop and is queued by LogOut.
// It logs the remaining sessions and re-runs check immediately, so the agent
// can exit right away if this was the last user and no hold file is present.
func logOut(sessionId int) {
	log.Println("User logged out")
	active := state.sessions
	delete(active, sessionId)
	logStatus()
	check()
}
// printMessage writes message to the session's stderr. Each line gets a "### "
// prefix and a "\n\r" terminator (the explicit carriage return is presumably for
// terminals in raw mode), and a blank line follows the message. Write errors
// are ignored; notifications are best effort.
func printMessage(sshSession ssh.Session, message string) {
	lines := strings.Split(message, "\n")
	for i := range lines {
		io.WriteString(sshSession.Stderr(), "### "+lines[i]+"\n\r")
	}
	io.WriteString(sshSession.Stderr(), "\n\r")
}
func logStatus() {
fmt := "%-20s %-20s %-20s"
log.Println()
log.Printf(fmt, "UID", "START_TIME", "TYPE")
for uid, session := range state.sessions {
sessionType := session.sshSession.Subsystem()
if sessionType == "" {
sessionType = "ssh"
}
log.Printf(fmt, strconv.Itoa(uid),
session.startTime.Format(time.DateTime),
sessionType)
}
log.Println()
}
// fileExistsWithStats reports whether filename could be stat'ed and, if so,
// returns its FileInfo. The boolean is true only when the FileInfo is valid.
// Any Stat error yields (nil, false), not only "does not exist" but also e.g.
// permission denied or "not a directory". Callers can therefore always use the
// FileInfo when the boolean is true.
func fileExistsWithStats(filename string) (os.FileInfo, bool) {
	stats, err := os.Stat(filename)
	if err != nil {
		return nil, false
	}
	return stats, true
}
// fileExists reports whether filename exists. Only a "does not exist" error
// from os.Stat counts as absent. Other Stat errors (e.g. permission denied)
// are treated as present, which keeps the agent from exiting in check when
// the hold file's state cannot be determined.
func fileExists(filename string) bool {
	if _, err := os.Stat(filename); os.IsNotExist(err) {
		return false
	}
	return true
}
func (state *AgentState) expiryTime(filename string) time.Time {
if !state.agentUsed {
return state.startTime.Add(state.agentExpriryTime)
}
expiryTime := time.Time{}
stats, err := os.Stat(filename)
if err == nil {
expiryTime = stats.ModTime().Add(state.agentExpriryTime)
}
userLoginBaseExpiryTime := state.lastUserLoginTime.Add(state.agentExpriryTime)
if userLoginBaseExpiryTime.After(expiryTime) {
expiryTime = userLoginBaseExpiryTime
}
return expiryTime
}
// holdFileChange recomputes the expiry time. If it differs from the value last
// reported, it notifies all users (including the hold-file explanation),
// remembers the new value and sends it to the controller over the side channel.
func holdFileChange() {
	newExpiryTime := state.expiryTime(holdFilename)
	// Compare instants with Equal: == also compares monotonic clock readings
	// and locations, which can report a change where there is none.
	if newExpiryTime.Equal(state.lastExpiryTimmeReported) {
		return
	}
	message := fmt.Sprintf("Expiry time of session is now %s\n",
		newExpiryTime.Format(time.DateTime))
	message += holdFileMessage()
	messageUsers(message)
	state.lastExpiryTimmeReported = newExpiryTime
	// Send exactly the value shown to users; recomputing it could observe a
	// newer hold-file modification time.
	state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(newExpiryTime))
}
// check enforces the agent's lifetime rules. It is run on every ticker tick
// and after every logout.
//
// Behavior to implement:
//  1. There is a global timeout for all agent sessions together:
//     state.agentExpriryTime.
//  2. Before the first login, the expiry time is relative to the start time of
//     the agent. After that, it is relative to the later of the last login
//     time and the modification time of the .hold file in the agent directory.
//  3. If we are close to the expiry time (within state.advanceWarningTime), we
//     message users with instructions on how to prevent the timeout.
//  4. If the last user logs out, the agent exits immediately if no .hold file
//     is present. Otherwise it exits at the expiry time. This allows users to
//     reconnect later.
func check() {
	now := time.Now()
	expiryTime := state.expiryTime(holdFilename)
	if now.After(expiryTime) {
		messageUsers("Expiry time was reached logging out")
		// Presumably gives users a chance to read the notification before their
		// connections are dropped. Note that this blocks the event loop.
		time.Sleep(5 * time.Second)
		log.Println("Agent exiting")
		os.Exit(0)
	}
	if expiryTime.Sub(now) < state.advanceWarningTime {
		messageUsers(
			fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
		for _, session := range state.sessions {
			printHelpMessage(session.sshSession)
		}
	}
	if state.agentUsed && !fileExists(holdFilename) && len(state.sessions) == 0 {
		log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename)
		os.Exit(0)
	}
}
// messageUsers logs message and writes it to every active session.
func messageUsers(message string) {
	log.Printf("=== Notification to users: %s", message)
	for id := range state.sessions {
		printMessage(state.sessions[id].sshSession, message)
	}
}