basic session management is now implemented.
This commit is contained in:
		
							parent
							
								
									ed922a235f
								
							
						
					
					
						commit
						e945e7453b
					
				| @ -2,6 +2,7 @@ package main | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
|  | 	"cidebug/pkg/agent" | ||||||
| 	"cidebug/pkg/iowrappers" | 	"cidebug/pkg/iowrappers" | ||||||
| 	"cidebug/pkg/websocketutil" | 	"cidebug/pkg/websocketutil" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @ -12,6 +13,7 @@ import ( | |||||||
| 	"os" | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
| 	"syscall" | 	"syscall" | ||||||
|  | 	"time" | ||||||
| 	"unsafe" | 	"unsafe" | ||||||
| 
 | 
 | ||||||
| 	"github.com/creack/pty" | 	"github.com/creack/pty" | ||||||
| @ -73,11 +75,16 @@ func sshServer(hostKeyFile string) *ssh.Server { | |||||||
| 					setWinsize(f, win.Width, win.Height) | 					setWinsize(f, win.Width, win.Height) | ||||||
| 				} | 				} | ||||||
| 			}() | 			}() | ||||||
|  | 			uid := int(time.Now().UnixMilli()) | ||||||
|  | 			agent.Login(uid, s) | ||||||
|  | 
 | ||||||
| 			go func() { | 			go func() { | ||||||
| 				io.Copy(f, s) // stdin
 | 				io.Copy(f, s) // stdin
 | ||||||
| 			}() | 			}() | ||||||
| 			io.Copy(s, f) // stdout
 | 			io.Copy(s, f) // stdout
 | ||||||
| 			cmd.Wait() | 			cmd.Wait() | ||||||
|  | 			log.Println("User logged out") | ||||||
|  | 			agent.LogOut(uid) | ||||||
| 		} else { | 		} else { | ||||||
| 			io.WriteString(s, "No PTY requested.\n") | 			io.WriteString(s, "No PTY requested.\n") | ||||||
| 			s.Exit(1) | 			s.Exit(1) | ||||||
| @ -145,6 +152,12 @@ func (f ReaderFunc) Read(p []byte) (n int, err error) { | |||||||
| 
 | 
 | ||||||
| func main() { | func main() { | ||||||
| 	wsURL := os.Args[1] | 	wsURL := os.Args[1] | ||||||
|  | 
 | ||||||
|  | 	advanceWarningTime := 1 * time.Minute | ||||||
|  | 	sessionExpiryTime := 5 * time.Minute | ||||||
|  | 	tickerInterval := 10 * time.Second | ||||||
|  | 	agent.ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval) | ||||||
|  | 
 | ||||||
| 	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | 	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Println("WebSocket connection error:", err) | 		log.Println("WebSocket connection error:", err) | ||||||
|  | |||||||
							
								
								
									
										87
									
								
								pkg/agent/session.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								pkg/agent/session.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | |||||||
|  | package agent | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"github.com/gliderlabs/ssh" | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // global configuration
 | ||||||
|  | 
 | ||||||
|  | type AgentState struct { | ||||||
|  | 
 | ||||||
|  | 	// Advance warning time to notify the user of something important happening
 | ||||||
|  | 	advanceWarningTime time.Duration | ||||||
|  | 
 | ||||||
|  | 	// session expiry time
 | ||||||
|  | 	sessionExpiryTime time.Duration | ||||||
|  | 
 | ||||||
|  | 	// ticker
 | ||||||
|  | 	tickerInterval time.Duration | ||||||
|  | 	ticker         *time.Ticker | ||||||
|  | 
 | ||||||
|  | 	// map of unique session id to a session
 | ||||||
|  | 	sessions map[int]*AgentSession | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AgentSession struct { | ||||||
|  | 	startTime time.Time | ||||||
|  | 
 | ||||||
|  | 	// For sending messages to the user
 | ||||||
|  | 	sshSession ssh.Session | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var state AgentState | ||||||
|  | 
 | ||||||
|  | func ConfigureAgent(advanceWarningTime, sessionExpiryTime, tickerInterval time.Duration) { | ||||||
|  | 	state = AgentState{ | ||||||
|  | 		advanceWarningTime: advanceWarningTime, | ||||||
|  | 		sessionExpiryTime:  sessionExpiryTime, | ||||||
|  | 		tickerInterval:     tickerInterval, | ||||||
|  | 		ticker:             time.NewTicker(tickerInterval), | ||||||
|  | 		sessions:           make(map[int]*AgentSession), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			<-state.ticker.C | ||||||
|  | 			check() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Login(sessionId int, sshSession ssh.Session) { | ||||||
|  | 	log.Println("New login") | ||||||
|  | 	agentSession := AgentSession{ | ||||||
|  | 		startTime:  time.Now(), | ||||||
|  | 		sshSession: sshSession, | ||||||
|  | 	} | ||||||
|  | 	state.sessions[sessionId] = &agentSession | ||||||
|  | 	LogStatus() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func LogOut(sessionId int) { | ||||||
|  | 	log.Println("User logged out") | ||||||
|  | 	delete(state.sessions, sessionId) | ||||||
|  | 	LogStatus() | ||||||
|  | 	check() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func LogStatus() { | ||||||
|  | 	fmt := "%-20s %-20s" | ||||||
|  | 	log.Println() | ||||||
|  | 	log.Printf(fmt, "UID", "START_TIME") | ||||||
|  | 	for uid, session := range state.sessions { | ||||||
|  | 		log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format("2006-01-02 15:04:05")) | ||||||
|  | 	} | ||||||
|  | 	log.Println() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func check() { | ||||||
|  | 	log.Println("Timer is firing!") | ||||||
|  | 	for _, session := range state.sessions { | ||||||
|  | 		io.WriteString(session.sshSession.Stderr(), "\n\nThe clock is ticking for you!\n\n") | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -1,4 +1,4 @@ | |||||||
| package sshutils | package agent | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| @ -169,7 +169,7 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser) error { | |||||||
| 	}() | 	}() | ||||||
| 	log.Printf("Agent registered: '%s'\n", publicId) | 	log.Printf("Agent registered: '%s'\n", publicId) | ||||||
| 	for !agent.clientSession.IsClosed() { | 	for !agent.clientSession.IsClosed() { | ||||||
| 		time.Sleep(1 * time.Second) | 		time.Sleep(250 * time.Millisecond) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user