simple session management solution with a .hold file and messages to the
user with better formatting.
This commit is contained in:
		
							parent
							
								
									134c72d8d0
								
							
						
					
					
						commit
						22a3589d1d
					
				| @ -60,12 +60,14 @@ func setWinsize(f *os.File, w, h int) { | ||||
| 
 | ||||
| func sshServer(hostKeyFile string) *ssh.Server { | ||||
| 	ssh.Handle(func(s ssh.Session) { | ||||
| 		hostname, _ := os.Hostname() | ||||
| 		io.WriteString(s, fmt.Sprintf("Your are now on %s\n\n", hostname)) | ||||
| 		// TODO shell should  be made configurable
 | ||||
| 		cmd := exec.Command("bash") | ||||
| 		ptyReq, winCh, isPty := s.Pty() | ||||
| 		if isPty { | ||||
| 			cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term)) | ||||
| 			workingDirectory, _ := os.Getwd() | ||||
| 			cmd.Env = append(os.Environ(), | ||||
| 				fmt.Sprintf("TERM=%s", ptyReq.Term), | ||||
| 				fmt.Sprintf("agentdir=%s", workingDirectory)) | ||||
| 			f, err := pty.Start(cmd) | ||||
| 			if err != nil { | ||||
| 				panic(err) | ||||
| @ -83,7 +85,6 @@ func sshServer(hostKeyFile string) *ssh.Server { | ||||
| 			}() | ||||
| 			io.Copy(s, f) // stdout
 | ||||
| 			cmd.Wait() | ||||
| 			log.Println("User logged out") | ||||
| 			agent.LogOut(uid) | ||||
| 		} else { | ||||
| 			io.WriteString(s, "No PTY requested.\n") | ||||
| @ -154,9 +155,9 @@ func main() { | ||||
| 	wsURL := os.Args[1] | ||||
| 
 | ||||
| 	advanceWarningTime := 1 * time.Minute | ||||
| 	sessionExpiryTime := 5 * time.Minute | ||||
| 	agentExpriryTime := 2 * time.Minute | ||||
| 	tickerInterval := 10 * time.Second | ||||
| 	agent.ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval) | ||||
| 	agent.ConfigureAgent(advanceWarningTime, agentExpriryTime, tickerInterval) | ||||
| 
 | ||||
| 	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | ||||
| 	if err != nil { | ||||
|  | ||||
							
								
								
									
										13
									
								
								pkg/agent/help.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								pkg/agent/help.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | ||||
| Session is set to expire at %s | ||||
| 
 | ||||
| The session expires automatically after %d time. | ||||
| If there are no more sessions after logging out, the agent | ||||
| terminates. | ||||
| 
 | ||||
| You can extend this time using | ||||
| 
 | ||||
|   touch $agentdir/.hold | ||||
| 
 | ||||
| To prevent the agent from exiting after the last sessioni is gone, | ||||
| also use the above command in any shell. | ||||
| 
 | ||||
| @ -1,22 +1,28 @@ | ||||
| package agent | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	_ "embed" | ||||
| ) | ||||
| 
 | ||||
| // global configuration
 | ||||
| 
 | ||||
| type AgentState struct { | ||||
| 	startTime time.Time | ||||
| 
 | ||||
| 	// Advance warning time to notify the user of something important happening
 | ||||
| 	advanceWarningTime time.Duration | ||||
| 
 | ||||
| 	// session expiry time
 | ||||
| 	sessionExpiryTime time.Duration | ||||
| 	agentExpriryTime time.Duration | ||||
| 
 | ||||
| 	// ticker
 | ||||
| 	tickerInterval time.Duration | ||||
| @ -35,10 +41,23 @@ type AgentSession struct { | ||||
| 
 | ||||
| var state AgentState | ||||
| 
 | ||||
| func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.Duration) { | ||||
| const holdFilename = ".hold" | ||||
| 
 | ||||
| //go:embed help.txt
 | ||||
| var helpMessage string | ||||
| 
 | ||||
| func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { | ||||
| 	if fileExists(holdFilename) { | ||||
| 		log.Printf("Removing hold file '%s'", holdFilename) | ||||
| 		err := os.Remove(holdFilename) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Could not remove hold file: '%s'", holdFilename) | ||||
| 		} | ||||
| 	} | ||||
| 	state = AgentState{ | ||||
| 		startTime:          time.Now(), | ||||
| 		advanceWarningTime: advanceWarningTime, | ||||
| 		sessionExpiryTime:  sessionExpiryTime, | ||||
| 		agentExpriryTime:   agentExpiryTime, | ||||
| 		tickerInterval:     tickerInterval, | ||||
| 		ticker:             time.NewTicker(tickerInterval), | ||||
| 		sessions:           make(map[int]*AgentSession), | ||||
| @ -54,6 +73,25 @@ func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.D | ||||
| 
 | ||||
| func Login(sessionId int, sshSession ssh.Session) { | ||||
| 	log.Println("New login") | ||||
| 	hostname, _ := os.Hostname() | ||||
| 
 | ||||
| 	holdFileStats, ok := fileExistsWithStats(holdFilename) | ||||
| 	if ok { | ||||
| 		if holdFileStats.ModTime().After(time.Now()) { | ||||
| 			// modification time in the future, leaving intact
 | ||||
| 			log.Println("Hold file has modification time in the future, leaving intact") | ||||
| 		} else { | ||||
| 			log.Printf("Touching hold file '%s'", holdFilename) | ||||
| 			err := os.Chtimes(holdFilename, time.Now(), time.Now()) | ||||
| 			if err != nil { | ||||
| 				log.Printf("Could not touch hold file: '%s'", holdFilename) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	PrintMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) | ||||
| 	PrintHelpMessage(sshSession) | ||||
| 
 | ||||
| 	agentSession := AgentSession{ | ||||
| 		startTime:  time.Now(), | ||||
| 		sshSession: sshSession, | ||||
| @ -62,6 +100,12 @@ func Login(sessionId int, sshSession ssh.Session) { | ||||
| 	LogStatus() | ||||
| } | ||||
| 
 | ||||
| func PrintHelpMessage(sshSession ssh.Session) { | ||||
| 	PrintMessage(sshSession, fmt.Sprintf(helpMessage, | ||||
| 		state.expiryTime(holdFilename).Format(time.DateTime), | ||||
| 		state.agentExpriryTime)) | ||||
| } | ||||
| 
 | ||||
| func LogOut(sessionId int) { | ||||
| 	log.Println("User logged out") | ||||
| 	delete(state.sessions, sessionId) | ||||
| @ -69,19 +113,75 @@ func LogOut(sessionId int) { | ||||
| 	check() | ||||
| } | ||||
| 
 | ||||
| func PrintMessage(sshSession ssh.Session, message string) { | ||||
| 	io.WriteString(sshSession.Stderr(), "\n\r###\n\r") | ||||
| 	for _, line := range strings.Split(message, "\n") { | ||||
| 		io.WriteString(sshSession.Stderr(), "### "+line+"\n\r") | ||||
| 	} | ||||
| 	io.WriteString(sshSession.Stderr(), "\n\r") | ||||
| } | ||||
| 
 | ||||
| func LogStatus() { | ||||
| 	fmt := "%-20s %-20s" | ||||
| 	log.Println() | ||||
| 	log.Printf(fmt, "UID", "START_TIME") | ||||
| 	for uid, session := range state.sessions { | ||||
| 		log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format("2006-01-02 15:04:05")) | ||||
| 		log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format(time.DateTime)) | ||||
| 	} | ||||
| 	log.Println() | ||||
| } | ||||
| 
 | ||||
| func fileExistsWithStats(filename string) (os.FileInfo, bool) { | ||||
| 	stats, err := os.Stat(filename) | ||||
| 	return stats, !os.IsNotExist(err) | ||||
| } | ||||
| 
 | ||||
| func fileExists(filename string) bool { | ||||
| 	_, err := os.Stat(filename) | ||||
| 	return !os.IsNotExist(err) | ||||
| } | ||||
| 
 | ||||
| func (state *AgentState) expiryTime(filename string) time.Time { | ||||
| 	stats, err := os.Stat(filename) | ||||
| 	if err != nil { | ||||
| 		return state.startTime | ||||
| 	} | ||||
| 	return stats.ModTime() | ||||
| } | ||||
| 
 | ||||
| // Behavior to implement
 | ||||
| //  1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
 | ||||
| //  2. The expiry time is relative to the modification time of the .hold file in the
 | ||||
| //     agent directory or, if that file does not exist, the start time of the agent.
 | ||||
| //  3. if we are close to the expiry time then we message users with instruction on
 | ||||
| //     how to prevent the timeout
 | ||||
| //  4. If the last user logs out, the aagent will exit immediately if no .hold file is
 | ||||
| //     present. Otherwise it will exit after the epxiry time. This allows users to
 | ||||
| //     reconnect later.
 | ||||
| func check() { | ||||
| 	log.Println("Timer is firing!") | ||||
| 	for _, session := range state.sessions { | ||||
| 		io.WriteString(session.sshSession.Stderr(), "\n\nThe clock is ticking for you!\n\n") | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	expiryTime := state.expiryTime(".hold").Add(state.agentExpriryTime) | ||||
| 
 | ||||
| 	if now.After(expiryTime) { | ||||
| 		messageUsers("Expiry time was reached logging out") | ||||
| 		time.Sleep(5 * time.Second) | ||||
| 		log.Println("Agent exiting") | ||||
| 		os.Exit(0) | ||||
| 	} | ||||
| 
 | ||||
| 	if expiryTime.Sub(now) < state.advanceWarningTime { | ||||
| 		messageUsers( | ||||
| 			fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) | ||||
| 		for _, session := range state.sessions { | ||||
| 			PrintHelpMessage(session.sshSession) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func messageUsers(message string) { | ||||
| 	log.Printf("=== Notification to users: %s", message) | ||||
| 	for _, session := range state.sessions { | ||||
| 		PrintMessage(session.sshSession, message) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -65,14 +65,14 @@ func (admin *Admin) logStatus() { | ||||
| 	for _, agent := range admin.agents { | ||||
| 		agent.clientSession.RemoteAddr() | ||||
| 		log.Printf("%-20s %-20s %-20s\n", agent.publicId, | ||||
| 			agent.startTime.Format("2006-01-02 15:04:05"), | ||||
| 			agent.startTime.Format(time.DateTime), | ||||
| 			agent.clientSession.RemoteAddr().String()) | ||||
| 	} | ||||
| 	log.Println("") | ||||
| 	log.Printf("%-20s %-20s %-20s\n", "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS") | ||||
| 	for _, client := range admin.clients { | ||||
| 		log.Printf("%-20s %-20s %-20s", client.publicId, | ||||
| 			client.startTime.Format("2006-01-02 15:04:05"), | ||||
| 			client.startTime.Format(time.DateTime), | ||||
| 			client.client.RemoteAddr()) | ||||
| 	} | ||||
| 	log.Printf("\n") | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user