An alternative solution would be to run all initialization code in goroutines, making it independent of initialization order, but having a defined initialization order is much cleaner.
		
			
				
	
	
		
			380 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			380 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package session
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"fmt"
 | |
| 	"git.wamblee.org/converge/pkg/comms"
 | |
| 	"github.com/fsnotify/fsnotify"
 | |
| 	"github.com/gliderlabs/ssh"
 | |
| 	"io"
 | |
| 	"log"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"runtime"
 | |
| 	"strings"
 | |
| 	"text/template"
 | |
| 	"time"
 | |
| 
 | |
| 	_ "embed"
 | |
| )
 | |
| 
 | |
| // TODO fix concurrency
 | |
| // All methods put a message on a channel
 | |
| //
 | |
| // Using a channel of functions will work.
 | |
| // When default is used, channel will block always and thereby
 | |
| // effectively serializing everything.
 | |
| //
 | |
| 
 | |
| // global configuration
 | |
| 
 | |
| type AgentState struct {
 | |
| 	commChannel comms.CommChannel
 | |
| 	startTime   time.Time
 | |
| 
 | |
| 	// Advance warning time to notify the user of something important happening
 | |
| 	advanceWarningDuration time.Duration
 | |
| 
 | |
| 	// session expiry time
 | |
| 	agentExpiryDuration time.Duration
 | |
| 
 | |
| 	// Last expiry time reported to the user.
 | |
| 	lastExpiryTimeReported time.Time
 | |
| 	expiryIsNear           bool
 | |
| 
 | |
| 	// ticker
 | |
| 	tickerInterval time.Duration
 | |
| 	ticker         *time.Ticker
 | |
| 
 | |
| 	// map of unique session id to a session
 | |
| 	clients map[string]*AgentSession
 | |
| 
 | |
| 	lastUserActivityTime time.Time
 | |
| 	agentUsed            bool
 | |
| }
 | |
| 
 | |
| type AgentSession struct {
 | |
| 	startTime time.Time
 | |
| 
 | |
| 	// For sending messages to the user
 | |
| 	sshSession ssh.Session
 | |
| }
 | |
| 
 | |
| var state AgentState
 | |
| 
 | |
| const holdFilename = ".hold"
 | |
| 
 | |
| //go:embed help.txt
 | |
| var helpMessageTemplate string
 | |
| var helpMessage = formatHelpMessage()
 | |
| 
 | |
| // Events channel for asynchronous events.
 | |
| var events = make(chan func())
 | |
| 
 | |
| // External interface, asynchronous, apart from the initialization.
 | |
| 
 | |
| func ConfigureAgent(commChannel comms.CommChannel,
 | |
| 	advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) {
 | |
| 	if fileExists(holdFilename) {
 | |
| 		log.Printf("Removing hold file '%s'", holdFilename)
 | |
| 		err := os.Remove(holdFilename)
 | |
| 		if err != nil {
 | |
| 			log.Printf("Could not remove hold file: '%s'", holdFilename)
 | |
| 		}
 | |
| 	}
 | |
| 	state = AgentState{
 | |
| 		commChannel:            commChannel,
 | |
| 		startTime:              time.Now(),
 | |
| 		advanceWarningDuration: advanceWarningTime,
 | |
| 		agentExpiryDuration:    agentExpiryTime,
 | |
| 		lastExpiryTimeReported: time.Time{},
 | |
| 		tickerInterval:         tickerInterval,
 | |
| 		ticker:                 time.NewTicker(tickerInterval),
 | |
| 		clients:                make(map[string]*AgentSession),
 | |
| 
 | |
| 		lastUserActivityTime: time.Time{},
 | |
| 		agentUsed:            false,
 | |
| 	}
 | |
| 
 | |
| 	log.Printf("Agent expiry duration is %v", state.agentExpiryDuration)
 | |
| 
 | |
| 	comms.Send(state.commChannel.SideChannel,
 | |
| 		comms.ConvergeMessage{
 | |
| 			Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
 | |
| 		})
 | |
| 
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			<-state.ticker.C
 | |
| 			events <- check
 | |
| 		}
 | |
| 	}()
 | |
| 	go monitorHoldFile()
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			event := <-events
 | |
| 			event()
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| }
 | |
| 
 | |
| func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
 | |
| 	events <- func() {
 | |
| 		login(sessionInfo, sshSession)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func LogOut(clientId string) {
 | |
| 	events <- func() {
 | |
| 		logOut(clientId)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func UserActivityDetected() {
 | |
| 	events <- func() {
 | |
| 		userActivityDetected()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func MessageUsers(message string) {
 | |
| 	events <- func() {
 | |
| 		messageUsers(message)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Internal interface synchronous
 | |
| 
 | |
| func userActivityDetected() {
 | |
| 	state.lastUserActivityTime = time.Now()
 | |
| 	if state.expiryIsNear {
 | |
| 		messageUsers("User activity detected, session extended.")
 | |
| 		sendExpiryTimeUpdateEvent()
 | |
| 		state.expiryIsNear = false
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func monitorHoldFile() {
 | |
| 	watcher, err := fsnotify.NewWatcher()
 | |
| 	if err != nil {
 | |
| 		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
 | |
| 	}
 | |
| 	defer watcher.Close()
 | |
| 	err = watcher.Add(".")
 | |
| 	if err != nil {
 | |
| 		log.Printf("Cannot watch hold file %s, user notifications for change in expiry time will be unavailable: %v", holdFilename, err)
 | |
| 	}
 | |
| 	for {
 | |
| 		select {
 | |
| 		case event, ok := <-watcher.Events:
 | |
| 			if !ok {
 | |
| 				return
 | |
| 			}
 | |
| 			base := filepath.Base(event.Name)
 | |
| 			if base == holdFilename {
 | |
| 				events <- holdFileChange
 | |
| 			}
 | |
| 
 | |
| 		case err, ok := <-watcher.Errors:
 | |
| 			if !ok {
 | |
| 				return
 | |
| 			}
 | |
| 			log.Println("Error:", err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func holdFileMessage() string {
 | |
| 	message := ""
 | |
| 	if fileExists(holdFilename) {
 | |
| 		message += fmt.Sprintf("When the last user exits, the session will timeout and not exit immediately.\n"+
 | |
| 			"Remove the %s file if you want th session to terminate when the last user logs out",
 | |
| 			holdFilename)
 | |
| 	} else {
 | |
| 		message += fmt.Sprintf("When the last user exits, the agent will return and the continuous\n" +
 | |
| 			"integration job will continue.")
 | |
| 
 | |
| 	}
 | |
| 	return message
 | |
| }
 | |
| 
 | |
| func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
 | |
| 	log.Println("New login")
 | |
| 	hostname, _ := os.Hostname()
 | |
| 
 | |
| 	holdFileStats, ok := fileExistsWithStats(holdFilename)
 | |
| 	if ok {
 | |
| 		if holdFileStats.ModTime().After(time.Now()) {
 | |
| 			// modification time in the future, leaving intact
 | |
| 			log.Println("Hold file has modification time in the future, leaving intact")
 | |
| 		} else {
 | |
| 			log.Printf("Touching hold file '%s'", holdFilename)
 | |
| 			err := os.Chtimes(holdFilename, time.Now(), time.Now())
 | |
| 			if err != nil {
 | |
| 				log.Printf("Could not touch hold file: '%s'", holdFilename)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	agentSession := AgentSession{
 | |
| 		startTime:  time.Now(),
 | |
| 		sshSession: sshSession,
 | |
| 	}
 | |
| 	state.clients[sessionInfo.ClientId] = &agentSession
 | |
| 	state.lastUserActivityTime = time.Now()
 | |
| 	state.agentUsed = true
 | |
| 
 | |
| 	err := comms.SendWithTimeout(state.commChannel.SideChannel,
 | |
| 		comms.ConvergeMessage{Value: sessionInfo})
 | |
| 	if err != nil {
 | |
| 		log.Printf("Could not send session info to converge server, information on server may be incomplete %v", err)
 | |
| 	}
 | |
| 
 | |
| 	logStatus()
 | |
| 
 | |
| 	printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
 | |
| 	holdFileChange()
 | |
| 	printHelpMessage(sshSession)
 | |
| }
 | |
| 
 | |
| func printHelpMessage(sshSession ssh.Session) {
 | |
| 	printMessage(sshSession, fmt.Sprintf(helpMessage,
 | |
| 		state.agentExpiryDuration))
 | |
| }
 | |
| 
 | |
| func formatHelpMessage() string {
 | |
| 	templ, err := template.New("help").Parse(helpMessageTemplate)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	helpFormattedBuf := bytes.NewBuffer(make([]byte, 0))
 | |
| 	log.Println("Running on ", runtime.GOOS)
 | |
| 	data := map[string]string{"os": runtime.GOOS}
 | |
| 	err = templ.Execute(helpFormattedBuf, data)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	helpFormatted := helpFormattedBuf.String()
 | |
| 	return helpFormatted
 | |
| }
 | |
| 
 | |
| func logOut(clientId string) {
 | |
| 	log.Println("User logged out")
 | |
| 	delete(state.clients, clientId)
 | |
| 	logStatus()
 | |
| 	check()
 | |
| }
 | |
| 
 | |
| func printMessage(sshSession ssh.Session, message string) {
 | |
| 	for _, line := range strings.Split(message, "\n") {
 | |
| 		io.WriteString(sshSession.Stderr(), "### "+line+"\n\r")
 | |
| 	}
 | |
| 	io.WriteString(sshSession.Stderr(), "\n\r")
 | |
| }
 | |
| 
 | |
| func logStatus() {
 | |
| 	fmt := "%-20s %-20s %-20s"
 | |
| 	log.Println()
 | |
| 	log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
 | |
| 	for uid, session := range state.clients {
 | |
| 		sessionType := session.sshSession.Subsystem()
 | |
| 		if sessionType == "" {
 | |
| 			sessionType = "ssh"
 | |
| 		}
 | |
| 		log.Printf(fmt, uid,
 | |
| 			session.startTime.Format(time.DateTime),
 | |
| 			sessionType)
 | |
| 	}
 | |
| 	log.Println()
 | |
| }
 | |
| 
 | |
// fileExistsWithStats stats filename and reports whether that succeeded.
// The returned FileInfo is non-nil exactly when the boolean is true, so
// callers may use it directly.
//
// Any Stat error (not only "does not exist", but also e.g. permission denied
// or "not a directory") yields (nil, false). Previously such errors produced
// (nil, true), which made callers dereference a nil FileInfo.
func fileExistsWithStats(filename string) (os.FileInfo, bool) {
	stats, err := os.Stat(filename)
	if err != nil {
		return nil, false
	}
	return stats, true
}
 | |
| 
 | |
// fileExists reports whether filename may exist. Only an os.Stat error that
// means "does not exist" yields false. Success and every other error (such as
// permission denied) yield true.
func fileExists(filename string) bool {
	if _, err := os.Stat(filename); os.IsNotExist(err) {
		return false
	}
	return true
}
 | |
| 
 | |
| func (state *AgentState) expiryTime(filename string) time.Time {
 | |
| 	if !state.agentUsed {
 | |
| 		return state.startTime.Add(state.agentExpiryDuration)
 | |
| 	}
 | |
| 	expiryTime := time.Time{}
 | |
| 	stats, err := os.Stat(filename)
 | |
| 	if err == nil {
 | |
| 		expiryTime = stats.ModTime().Add(state.agentExpiryDuration)
 | |
| 	}
 | |
| 	userLoginBaseExpiryTime := state.lastUserActivityTime.Add(state.agentExpiryDuration)
 | |
| 	if userLoginBaseExpiryTime.After(expiryTime) {
 | |
| 		expiryTime = userLoginBaseExpiryTime
 | |
| 	}
 | |
| 	return expiryTime
 | |
| }
 | |
| 
 | |
| func holdFileChange() {
 | |
| 	newExpiryTime := state.expiryTime(holdFilename)
 | |
| 	if newExpiryTime != state.lastExpiryTimeReported {
 | |
| 		message := fmt.Sprintf("Expiry time of session is now %s\n",
 | |
| 			newExpiryTime.Format(time.DateTime))
 | |
| 		message += holdFileMessage()
 | |
| 		messageUsers(message)
 | |
| 		state.lastExpiryTimeReported = newExpiryTime
 | |
| 		sendExpiryTimeUpdateEvent()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func sendExpiryTimeUpdateEvent() error {
 | |
| 	return comms.Send(state.commChannel.SideChannel,
 | |
| 		comms.ConvergeMessage{
 | |
| 			Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
 | |
| 		})
 | |
| }
 | |
| 
 | |
| // Behavior to implement
 | |
| //  1. there is a global timeout for all agent clients together: state.agentExpirtyTime
 | |
| //  2. The expiry time is relative to the modification time of the .hold file in the
 | |
| //     agent directory or, if that file does not exist, the start time of the agent.
 | |
| //  3. if we are close to the expiry time then we message users with instruction on
 | |
| //     how to prevent the timeout
 | |
| //  4. If the last user logs out, the aagent will exit immediately if no .hold file is
 | |
| //     present. Otherwise it will exit after the epxiry time. This allows users to
 | |
| //     reconnect later.
 | |
| func check() {
 | |
| 	now := time.Now()
 | |
| 
 | |
| 	expiryTime := state.expiryTime(".hold")
 | |
| 
 | |
| 	if now.After(expiryTime) {
 | |
| 		messageUsers("Expiry time was reached logging out")
 | |
| 		time.Sleep(5 * time.Second)
 | |
| 		log.Println("Agent exiting")
 | |
| 		os.Exit(0)
 | |
| 	}
 | |
| 
 | |
| 	state.expiryIsNear = expiryTime.Sub(now) < state.advanceWarningDuration
 | |
| 	if state.expiryIsNear {
 | |
| 		messageUsers(
 | |
| 			fmt.Sprintf("Session will expire at %s, press any key to (ssh) or execute a command (sftp) to extend it.", expiryTime.Format(time.DateTime)))
 | |
| 		//for _, session := range state.clients {
 | |
| 		//	printHelpMessage(session.sshSession)
 | |
| 		//}
 | |
| 	}
 | |
| 
 | |
| 	if state.agentUsed && !fileExists(holdFilename) && len(state.clients) == 0 {
 | |
| 		log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename)
 | |
| 		os.Exit(0)
 | |
| 	}
 | |
| 
 | |
| 	sendExpiryTimeUpdateEvent()
 | |
| }
 | |
| 
 | |
| func messageUsers(message string) {
 | |
| 	log.Printf("=== Notification to users: %s", message)
 | |
| 	for _, session := range state.clients {
 | |
| 		printMessage(session.sshSession, message)
 | |
| 	}
 | |
| }
 |