From bf5120aa5b5be248ff737af39ae572e518d59b0d Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Tue, 30 Jul 2024 19:03:21 +0200 Subject: [PATCH] refactoring towards being able to send events from Admin to UI (websocket) without exposing connection info but only metadata. --- cmd/agent/agent.go | 17 ++++--- pkg/agent/session.go | 46 ++++++++--------- pkg/comms/agentserver.go | 35 ------------- pkg/comms/events.go | 9 +++- pkg/converge/admin.go | 105 +++++++++++++++++++-------------------- pkg/models/agent.go | 14 ++++++ pkg/models/client.go | 12 +++++ 7 files changed, 118 insertions(+), 120 deletions(-) create mode 100644 pkg/models/agent.go create mode 100644 pkg/models/client.go diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 74ceb34..4a03149 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -33,9 +33,12 @@ import ( var hostPrivateKey []byte func SftpHandler(sess ssh.Session) { - uid := sess.LocalAddr().String() - agent.Login(uid, sess) - defer agent.LogOut(uid) + sessionInfo := comms.NewSessionInfo( + sess.LocalAddr().String(), + "sftp", + ) + agent.Login(sessionInfo, sess) + defer agent.LogOut(sessionInfo.ClientId) debugStream := io.Discard serverOptions := []sftp.ServerOption{ @@ -67,10 +70,12 @@ func sshServer(hostKeyFile string, shellCommand string, if err != nil { panic(err) } - uid := s.LocalAddr().String() - agent.Login(uid, s) + sessionInfo := comms.NewSessionInfo( + s.LocalAddr().String(), "ssh", + ) + agent.Login(sessionInfo, s) iowrappers.SynchronizeStreams(process.Pipe(), s) - agent.LogOut(uid) + agent.LogOut(sessionInfo.ClientId) // will cause addition goroutines to remmain alive when the SSH // session is killed. For now acceptable since the agent is a short-lived // process. Using Kill() here will create defunct processes and in normal diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 10774c2..4c0ec03 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -47,7 +47,7 @@ type AgentState struct { ticker *time.Ticker // map of unique session id to a session - sessions map[string]*AgentSession + clients map[string]*AgentSession lastUserLoginTime time.Time agentUsed bool @@ -90,7 +90,7 @@ func ConfigureAgent(commChannel comms.CommChannel, lastExpiryTimmeReported: time.Time{}, tickerInterval: tickerInterval, ticker: time.NewTicker(tickerInterval), - sessions: make(map[string]*AgentSession), + clients: make(map[string]*AgentSession), lastUserLoginTime: time.Time{}, agentUsed: false, @@ -120,12 +120,12 @@ func ConfigureAgent(commChannel comms.CommChannel, } -func Login(sessionId string, sshSession ssh.Session) { - events <- async.Async(login, sessionId, sshSession) +func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { + events <- async.Async(login, sessionInfo, sshSession) } -func LogOut(sessionId string) { - events <- async.Async(logOut, sessionId) +func LogOut(clientId string) { + events <- async.Async(logOut, clientId) } // Internal interface synchronous @@ -174,19 +174,10 @@ func holdFileMessage() string { return message } -func login(sessionId string, sshSession ssh.Session) { +func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) { log.Println("New login") hostname, _ := os.Hostname() - sessionType := sshSession.Subsystem() - if sessionType == "" { - sessionType = "ssh" - } - comms.Send(state.commChannel.SideChannel, - comms.ConvergeMessage{ - Value: comms.NewSessionInfo(sessionType), - }) - holdFileStats, ok := fileExistsWithStats(holdFilename) if ok { if holdFileStats.ModTime().After(time.Now()) { @@ -205,9 +196,16 @@ func login(sessionId string, sshSession ssh.Session) { startTime: time.Now(), sshSession: sshSession, } - state.sessions[sessionId] = &agentSession + state.clients[sessionInfo.ClientId] = &agentSession state.lastUserLoginTime = time.Now() state.agentUsed = true + + err := comms.SendWithTimeout(state.commChannel.SideChannel, + comms.ConvergeMessage{Value: sessionInfo}) + if err != nil { + log.Printf("Could not send session info to converge server, information on server may be incomplete %v", err) + } + logStatus() printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) @@ -236,9 +234,9 @@ func formatHelpMessage() string { return helpFormatted } -func logOut(sessionId string) { +func logOut(clientId string) { log.Println("User logged out") - delete(state.sessions, sessionId) + delete(state.clients, clientId) logStatus() check() } @@ -254,7 +252,7 @@ func logStatus() { fmt := "%-20s %-20s %-20s" log.Println() log.Printf(fmt, "CLIENT", "START_TIME", "TYPE") - for uid, session := range state.sessions { + for uid, session := range state.clients { sessionType := session.sshSession.Subsystem() if sessionType == "" { sessionType = "ssh" @@ -308,7 +306,7 @@ func holdFileChange() { } // Behavior to implement -// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime +// 1. there is a global timeout for all agent clients together: state.agentExpirtyTime // 2. The expiry time is relative to the modification time of the .hold file in the // agent directory or, if that file does not exist, the start time of the agent. // 3. if we are close to the expiry time then we message users with instruction on @@ -331,12 +329,12 @@ func check() { if expiryTime.Sub(now) < state.advanceWarningTime { messageUsers( fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) - for _, session := range state.sessions { + for _, session := range state.clients { printHelpMessage(session.sshSession) } } - if state.agentUsed && !fileExists(holdFilename) && len(state.sessions) == 0 { + if state.agentUsed && !fileExists(holdFilename) && len(state.clients) == 0 { log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename) os.Exit(0) } @@ -344,7 +342,7 @@ func check() { func messageUsers(message string) { log.Printf("=== Notification to users: %s", message) - for _, session := range state.sessions { + for _, session := range state.clients { printMessage(session.sshSession, message) } } diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index a1ebca8..10bf612 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -218,41 +218,6 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error { } } -func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error { - channel := NewGOBChannel(conn) - - sends := make(chan bool, 10) - receives := make(chan ProtocolVersion, 10) - errors := make(chan error, 10) - - SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) - ReceiveAsync(channel, receives, errors) - - select { - case <-time.After(MESSAGE_TIMEOUT): - log.Println("PROTOCOLVERSION: timeout") - return fmt.Errorf("Timeout waiting for protocol version") - case err := <-errors: - log.Printf("PROTOCOLVERSION: %v", err) - return err - case protocolVersion := <-receives: - otherVersion := protocolVersion.Version - if PROTOCOL_VERSION != otherVersion { - switch role { - case Agent: - log.Printf("Protocol version mismatch: agent %d, converge server %d", - PROTOCOL_VERSION, otherVersion) - case ConvergeServer: - log.Printf("Protocol version mismatch: agent %d, converge server %d", - otherVersion, PROTOCOL_VERSION) - } - return fmt.Errorf("Protocol version mismatch") - } - log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version) - return nil - } -} - // Session info metadata exchange. These are sent over the SSH connection. The agent embedded // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that // decorates the yamux Session (which is a listener) and uses this connection to exchange some diff --git a/pkg/comms/events.go b/pkg/comms/events.go index 83f0c9d..c3d50ff 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -28,6 +28,8 @@ type ClientInfo struct { } type SessionInfo struct { + ClientId string + // "ssh", "sftp" SessionType string } @@ -73,8 +75,11 @@ func NewAgentInfo() AgentInfo { } } -func NewSessionInfo(sessionType string) SessionInfo { - return SessionInfo{SessionType: sessionType} +func NewSessionInfo(clientId, sessionType string) SessionInfo { + return SessionInfo{ + ClientId: clientId, + SessionType: sessionType, + } } func NewExpiryTimeUpdate(expiryTime time.Time) ExpiryTimeUpdate { diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index ca8ab55..faf1000 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -4,6 +4,7 @@ import ( "converge/pkg/comms" "converge/pkg/concurrency" "converge/pkg/iowrappers" + "converge/pkg/models" "fmt" "io" "log" @@ -13,59 +14,56 @@ import ( "time" ) -type Agent struct { +type AgentConnection struct { + models.Agent // server session commChannel comms.CommChannel - publicId string - startTime time.Time - - agentInfo comms.AgentInfo - expiryTime time.Time } var clientIdGenerator = concurrency.NewAtomicCounter() -type Client struct { - publicId string - clientId int - agent net.Conn - client iowrappers.ReadWriteAddrCloser - startTime time.Time - sessionType string +type ClientConnection struct { + models.Client + agent net.Conn + client iowrappers.ReadWriteAddrCloser } -func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent { - return &Agent{ +func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *AgentConnection { + return &AgentConnection{ + Agent: models.Agent{ + PublicId: publicId, + StartTime: time.Now(), + AgentInfo: agentInfo, + }, commChannel: commChannel, - publicId: publicId, - startTime: time.Now(), - agentInfo: agentInfo, } } func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser, - agentConn net.Conn) *Client { - return &Client{ - publicId: publicId, - clientId: clientIdGenerator.IncrementAndGet(), - agent: agentConn, - client: clientConn, - startTime: time.Now(), + agentConn net.Conn) *ClientConnection { + return &ClientConnection{ + Client: models.Client{ + PublicId: publicId, + ClientId: clientIdGenerator.IncrementAndGet(), + StartTime: time.Now(), + }, + agent: agentConn, + client: clientConn, } } type Admin struct { // map of public id to agent mutex sync.Mutex - agents map[string]*Agent - clients []*Client + agents map[string]*AgentConnection + clients []*ClientConnection } func NewAdmin() *Admin { admin := Admin{ mutex: sync.Mutex{}, - agents: make(map[string]*Agent), - clients: make([]*Client, 0), // not strictly needed + agents: make(map[string]*AgentConnection), + clients: make([]*ClientConnection, 0), // not strictly needed } return &admin } @@ -76,34 +74,34 @@ func (admin *Admin) logStatus() { "USER", "HOST", "OS") for _, agent := range admin.agents { agent.commChannel.Session.RemoteAddr() - log.Printf(fmt, agent.publicId, - agent.startTime.Format(time.DateTime), - agent.expiryTime.Format(time.DateTime), - agent.agentInfo.Username, - agent.agentInfo.Hostname, - agent.agentInfo.OS) + log.Printf(fmt, agent.PublicId, + agent.StartTime.Format(time.DateTime), + agent.ExpiryTime.Format(time.DateTime), + agent.AgentInfo.Username, + agent.AgentInfo.Hostname, + agent.AgentInfo.OS) } log.Println("") fmt = "%-10s %-20s %-20s %-20s %-20s\n" log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE") for _, client := range admin.clients { log.Printf(fmt, - strconv.Itoa(client.clientId), - client.publicId, - client.startTime.Format(time.DateTime), + strconv.Itoa(client.ClientId), + client.PublicId, + client.StartTime.Format(time.DateTime), client.client.RemoteAddr(), - client.sessionType) + client.SessionType) } log.Printf("\n") } -func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) { +func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*AgentConnection, error) { admin.mutex.Lock() defer admin.mutex.Unlock() agent := admin.agents[publicId] if agent != nil { - return nil, fmt.Errorf("A different agent with same publicId '%s' already registered", publicId) + return nil, fmt.Errorf("A different agent with same PublicId '%s' already registered", publicId) } commChannel, err := comms.NewCommChannel(comms.ConvergeServer, conn) @@ -117,14 +115,14 @@ func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io return agent, nil } -func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*Client, error) { +func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*ClientConnection, error) { admin.mutex.Lock() defer admin.mutex.Unlock() agent := admin.agents[publicId] if agent == nil { // we should setup on-demend connections ot agents later. - return nil, fmt.Errorf("No agent found for publicId '%s'", publicId) + return nil, fmt.Errorf("No agent found for PublicId '%s'", publicId) } agentConn, err := agent.commChannel.Session.Open() @@ -139,7 +137,7 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd // Before using this connection for SSH we use it to send client metadata to the // agent err = comms.SendClientInfo(agentConn, comms.ClientInfo{ - ClientId: client.clientId, + ClientId: client.ClientId, }) if err != nil { return nil, err @@ -168,18 +166,18 @@ func (admin *Admin) RemoveAgent(publicId string) error { return nil } -func (admin *Admin) RemoveClient(client *Client) error { +func (admin *Admin) RemoveClient(client *ClientConnection) error { admin.mutex.Lock() defer admin.mutex.Unlock() - log.Printf("Removing client: '%s' created at %s\n", client.publicId, - client.startTime.Format("2006-01-02 15:04:05")) + log.Printf("Removing client: '%d' created at %s\n", client.ClientId, + client.StartTime.Format(time.DateTime)) // try to explicitly close connection to the agent. _ = client.agent.Close() _ = client.client.Close() for i, _client := range admin.clients { - if _client == _client { + if _client.ClientId == client.ClientId { admin.clients = append(admin.clients[:i], admin.clients[i+1:]...) break } @@ -212,26 +210,27 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, go func() { comms.ListenForAgentEvents(agent.commChannel.SideChannel, func(info comms.AgentInfo) { - agent.agentInfo = info + agent.AgentInfo = info admin.logStatus() }, func(session comms.SessionInfo) { + log.Println("Recceived sessioninfo ", session) for _, client := range admin.clients { // a bit hacky. There should be at most one client that has an unset session // Very unlikely for multiple sessions to start at the same point in time. - if client.publicId == agent.publicId && client.sessionType != session.SessionType { - client.sessionType = session.SessionType + if strconv.Itoa(client.ClientId) == session.ClientId { + client.SessionType = session.SessionType break } } }, func(expiry comms.ExpiryTimeUpdate) { - agent.expiryTime = expiry.ExpiryTime + agent.ExpiryTime = expiry.ExpiryTime admin.logStatus() }) }() - go log.Printf("Agent registered: '%s'\n", publicId) + go log.Printf("AgentConnection registered: '%s'\n", publicId) for !agent.commChannel.Session.IsClosed() { time.Sleep(250 * time.Millisecond) } diff --git a/pkg/models/agent.go b/pkg/models/agent.go new file mode 100644 index 0000000..5660039 --- /dev/null +++ b/pkg/models/agent.go @@ -0,0 +1,14 @@ +package models + +import ( + "converge/pkg/comms" + "time" +) + +type Agent struct { + PublicId string + StartTime time.Time + + AgentInfo comms.AgentInfo + ExpiryTime time.Time +} diff --git a/pkg/models/client.go b/pkg/models/client.go new file mode 100644 index 0000000..17136ea --- /dev/null +++ b/pkg/models/client.go @@ -0,0 +1,12 @@ +package models + +import ( + "time" +) + +type Client struct { + PublicId string + ClientId int + StartTime time.Time + SessionType string +}