refactoring towards being able to send events from Admin to UI (websocket) without exposing connection info but only metadata.

This commit is contained in:
Erik Brakkee 2024-07-30 19:03:21 +02:00
parent 5533b04a5e
commit bf5120aa5b
7 changed files with 118 additions and 120 deletions

View File

@ -33,9 +33,12 @@ import (
var hostPrivateKey []byte var hostPrivateKey []byte
func SftpHandler(sess ssh.Session) { func SftpHandler(sess ssh.Session) {
uid := sess.LocalAddr().String() sessionInfo := comms.NewSessionInfo(
agent.Login(uid, sess) sess.LocalAddr().String(),
defer agent.LogOut(uid) "sftp",
)
agent.Login(sessionInfo, sess)
defer agent.LogOut(sessionInfo.ClientId)
debugStream := io.Discard debugStream := io.Discard
serverOptions := []sftp.ServerOption{ serverOptions := []sftp.ServerOption{
@ -67,10 +70,12 @@ func sshServer(hostKeyFile string, shellCommand string,
if err != nil { if err != nil {
panic(err) panic(err)
} }
uid := s.LocalAddr().String() sessionInfo := comms.NewSessionInfo(
agent.Login(uid, s) s.LocalAddr().String(), "ssh",
)
agent.Login(sessionInfo, s)
iowrappers.SynchronizeStreams(process.Pipe(), s) iowrappers.SynchronizeStreams(process.Pipe(), s)
agent.LogOut(uid) agent.LogOut(sessionInfo.ClientId)
// will cause addition goroutines to remmain alive when the SSH // will cause addition goroutines to remmain alive when the SSH
// session is killed. For now acceptable since the agent is a short-lived // session is killed. For now acceptable since the agent is a short-lived
// process. Using Kill() here will create defunct processes and in normal // process. Using Kill() here will create defunct processes and in normal

View File

@ -47,7 +47,7 @@ type AgentState struct {
ticker *time.Ticker ticker *time.Ticker
// map of unique session id to a session // map of unique session id to a session
sessions map[string]*AgentSession clients map[string]*AgentSession
lastUserLoginTime time.Time lastUserLoginTime time.Time
agentUsed bool agentUsed bool
@ -90,7 +90,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
lastExpiryTimmeReported: time.Time{}, lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval, tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval), ticker: time.NewTicker(tickerInterval),
sessions: make(map[string]*AgentSession), clients: make(map[string]*AgentSession),
lastUserLoginTime: time.Time{}, lastUserLoginTime: time.Time{},
agentUsed: false, agentUsed: false,
@ -120,12 +120,12 @@ func ConfigureAgent(commChannel comms.CommChannel,
} }
func Login(sessionId string, sshSession ssh.Session) { func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession) events <- async.Async(login, sessionInfo, sshSession)
} }
func LogOut(sessionId string) { func LogOut(clientId string) {
events <- async.Async(logOut, sessionId) events <- async.Async(logOut, clientId)
} }
// Internal interface synchronous // Internal interface synchronous
@ -174,19 +174,10 @@ func holdFileMessage() string {
return message return message
} }
func login(sessionId string, sshSession ssh.Session) { func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
log.Println("New login") log.Println("New login")
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
sessionType := sshSession.Subsystem()
if sessionType == "" {
sessionType = "ssh"
}
comms.Send(state.commChannel.SideChannel,
comms.ConvergeMessage{
Value: comms.NewSessionInfo(sessionType),
})
holdFileStats, ok := fileExistsWithStats(holdFilename) holdFileStats, ok := fileExistsWithStats(holdFilename)
if ok { if ok {
if holdFileStats.ModTime().After(time.Now()) { if holdFileStats.ModTime().After(time.Now()) {
@ -205,9 +196,16 @@ func login(sessionId string, sshSession ssh.Session) {
startTime: time.Now(), startTime: time.Now(),
sshSession: sshSession, sshSession: sshSession,
} }
state.sessions[sessionId] = &agentSession state.clients[sessionInfo.ClientId] = &agentSession
state.lastUserLoginTime = time.Now() state.lastUserLoginTime = time.Now()
state.agentUsed = true state.agentUsed = true
err := comms.SendWithTimeout(state.commChannel.SideChannel,
comms.ConvergeMessage{Value: sessionInfo})
if err != nil {
log.Printf("Could not send session info to converge server, information on server may be incomplete %v", err)
}
logStatus() logStatus()
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname)) printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
@ -236,9 +234,9 @@ func formatHelpMessage() string {
return helpFormatted return helpFormatted
} }
func logOut(sessionId string) { func logOut(clientId string) {
log.Println("User logged out") log.Println("User logged out")
delete(state.sessions, sessionId) delete(state.clients, clientId)
logStatus() logStatus()
check() check()
} }
@ -254,7 +252,7 @@ func logStatus() {
fmt := "%-20s %-20s %-20s" fmt := "%-20s %-20s %-20s"
log.Println() log.Println()
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE") log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
for uid, session := range state.sessions { for uid, session := range state.clients {
sessionType := session.sshSession.Subsystem() sessionType := session.sshSession.Subsystem()
if sessionType == "" { if sessionType == "" {
sessionType = "ssh" sessionType = "ssh"
@ -308,7 +306,7 @@ func holdFileChange() {
} }
// Behavior to implement // Behavior to implement
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime // 1. there is a global timeout for all agent clients together: state.agentExpirtyTime
// 2. The expiry time is relative to the modification time of the .hold file in the // 2. The expiry time is relative to the modification time of the .hold file in the
// agent directory or, if that file does not exist, the start time of the agent. // agent directory or, if that file does not exist, the start time of the agent.
// 3. if we are close to the expiry time then we message users with instruction on // 3. if we are close to the expiry time then we message users with instruction on
@ -331,12 +329,12 @@ func check() {
if expiryTime.Sub(now) < state.advanceWarningTime { if expiryTime.Sub(now) < state.advanceWarningTime {
messageUsers( messageUsers(
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime))) fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
for _, session := range state.sessions { for _, session := range state.clients {
printHelpMessage(session.sshSession) printHelpMessage(session.sshSession)
} }
} }
if state.agentUsed && !fileExists(holdFilename) && len(state.sessions) == 0 { if state.agentUsed && !fileExists(holdFilename) && len(state.clients) == 0 {
log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename) log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename)
os.Exit(0) os.Exit(0)
} }
@ -344,7 +342,7 @@ func check() {
func messageUsers(message string) { func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message) log.Printf("=== Notification to users: %s", message)
for _, session := range state.sessions { for _, session := range state.clients {
printMessage(session.sshSession, message) printMessage(session.sshSession, message)
} }
} }

View File

@ -218,41 +218,6 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error {
} }
} }
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
channel := NewGOBChannel(conn)
sends := make(chan bool, 10)
receives := make(chan ProtocolVersion, 10)
errors := make(chan error, 10)
SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
ReceiveAsync(channel, receives, errors)
select {
case <-time.After(MESSAGE_TIMEOUT):
log.Println("PROTOCOLVERSION: timeout")
return fmt.Errorf("Timeout waiting for protocol version")
case err := <-errors:
log.Printf("PROTOCOLVERSION: %v", err)
return err
case protocolVersion := <-receives:
otherVersion := protocolVersion.Version
if PROTOCOL_VERSION != otherVersion {
switch role {
case Agent:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
PROTOCOL_VERSION, otherVersion)
case ConvergeServer:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
otherVersion, PROTOCOL_VERSION)
}
return fmt.Errorf("Protocol version mismatch")
}
log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
return nil
}
}
// Session info metadata exchange. These are sent over the SSH connection. The agent embedded // Session info metadata exchange. These are sent over the SSH connection. The agent embedded
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
// decorates the yamux Session (which is a listener) and uses this connection to exchange some // decorates the yamux Session (which is a listener) and uses this connection to exchange some

View File

@ -28,6 +28,8 @@ type ClientInfo struct {
} }
type SessionInfo struct { type SessionInfo struct {
ClientId string
// "ssh", "sftp" // "ssh", "sftp"
SessionType string SessionType string
} }
@ -73,8 +75,11 @@ func NewAgentInfo() AgentInfo {
} }
} }
func NewSessionInfo(sessionType string) SessionInfo { func NewSessionInfo(clientId, sessionType string) SessionInfo {
return SessionInfo{SessionType: sessionType} return SessionInfo{
ClientId: clientId,
SessionType: sessionType,
}
} }
func NewExpiryTimeUpdate(expiryTime time.Time) ExpiryTimeUpdate { func NewExpiryTimeUpdate(expiryTime time.Time) ExpiryTimeUpdate {

View File

@ -4,6 +4,7 @@ import (
"converge/pkg/comms" "converge/pkg/comms"
"converge/pkg/concurrency" "converge/pkg/concurrency"
"converge/pkg/iowrappers" "converge/pkg/iowrappers"
"converge/pkg/models"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -13,59 +14,56 @@ import (
"time" "time"
) )
type Agent struct { type AgentConnection struct {
models.Agent
// server session // server session
commChannel comms.CommChannel commChannel comms.CommChannel
publicId string
startTime time.Time
agentInfo comms.AgentInfo
expiryTime time.Time
} }
var clientIdGenerator = concurrency.NewAtomicCounter() var clientIdGenerator = concurrency.NewAtomicCounter()
type Client struct { type ClientConnection struct {
publicId string models.Client
clientId int
agent net.Conn agent net.Conn
client iowrappers.ReadWriteAddrCloser client iowrappers.ReadWriteAddrCloser
startTime time.Time
sessionType string
} }
func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent { func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *AgentConnection {
return &Agent{ return &AgentConnection{
Agent: models.Agent{
PublicId: publicId,
StartTime: time.Now(),
AgentInfo: agentInfo,
},
commChannel: commChannel, commChannel: commChannel,
publicId: publicId,
startTime: time.Now(),
agentInfo: agentInfo,
} }
} }
func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser, func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
agentConn net.Conn) *Client { agentConn net.Conn) *ClientConnection {
return &Client{ return &ClientConnection{
publicId: publicId, Client: models.Client{
clientId: clientIdGenerator.IncrementAndGet(), PublicId: publicId,
ClientId: clientIdGenerator.IncrementAndGet(),
StartTime: time.Now(),
},
agent: agentConn, agent: agentConn,
client: clientConn, client: clientConn,
startTime: time.Now(),
} }
} }
type Admin struct { type Admin struct {
// map of public id to agent // map of public id to agent
mutex sync.Mutex mutex sync.Mutex
agents map[string]*Agent agents map[string]*AgentConnection
clients []*Client clients []*ClientConnection
} }
func NewAdmin() *Admin { func NewAdmin() *Admin {
admin := Admin{ admin := Admin{
mutex: sync.Mutex{}, mutex: sync.Mutex{},
agents: make(map[string]*Agent), agents: make(map[string]*AgentConnection),
clients: make([]*Client, 0), // not strictly needed clients: make([]*ClientConnection, 0), // not strictly needed
} }
return &admin return &admin
} }
@ -76,34 +74,34 @@ func (admin *Admin) logStatus() {
"USER", "HOST", "OS") "USER", "HOST", "OS")
for _, agent := range admin.agents { for _, agent := range admin.agents {
agent.commChannel.Session.RemoteAddr() agent.commChannel.Session.RemoteAddr()
log.Printf(fmt, agent.publicId, log.Printf(fmt, agent.PublicId,
agent.startTime.Format(time.DateTime), agent.StartTime.Format(time.DateTime),
agent.expiryTime.Format(time.DateTime), agent.ExpiryTime.Format(time.DateTime),
agent.agentInfo.Username, agent.AgentInfo.Username,
agent.agentInfo.Hostname, agent.AgentInfo.Hostname,
agent.agentInfo.OS) agent.AgentInfo.OS)
} }
log.Println("") log.Println("")
fmt = "%-10s %-20s %-20s %-20s %-20s\n" fmt = "%-10s %-20s %-20s %-20s %-20s\n"
log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE") log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
for _, client := range admin.clients { for _, client := range admin.clients {
log.Printf(fmt, log.Printf(fmt,
strconv.Itoa(client.clientId), strconv.Itoa(client.ClientId),
client.publicId, client.PublicId,
client.startTime.Format(time.DateTime), client.StartTime.Format(time.DateTime),
client.client.RemoteAddr(), client.client.RemoteAddr(),
client.sessionType) client.SessionType)
} }
log.Printf("\n") log.Printf("\n")
} }
func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) { func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*AgentConnection, error) {
admin.mutex.Lock() admin.mutex.Lock()
defer admin.mutex.Unlock() defer admin.mutex.Unlock()
agent := admin.agents[publicId] agent := admin.agents[publicId]
if agent != nil { if agent != nil {
return nil, fmt.Errorf("A different agent with same publicId '%s' already registered", publicId) return nil, fmt.Errorf("A different agent with same PublicId '%s' already registered", publicId)
} }
commChannel, err := comms.NewCommChannel(comms.ConvergeServer, conn) commChannel, err := comms.NewCommChannel(comms.ConvergeServer, conn)
@ -117,14 +115,14 @@ func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io
return agent, nil return agent, nil
} }
func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*Client, error) { func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*ClientConnection, error) {
admin.mutex.Lock() admin.mutex.Lock()
defer admin.mutex.Unlock() defer admin.mutex.Unlock()
agent := admin.agents[publicId] agent := admin.agents[publicId]
if agent == nil { if agent == nil {
// we should setup on-demend connections ot agents later. // we should setup on-demend connections ot agents later.
return nil, fmt.Errorf("No agent found for publicId '%s'", publicId) return nil, fmt.Errorf("No agent found for PublicId '%s'", publicId)
} }
agentConn, err := agent.commChannel.Session.Open() agentConn, err := agent.commChannel.Session.Open()
@ -139,7 +137,7 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
// Before using this connection for SSH we use it to send client metadata to the // Before using this connection for SSH we use it to send client metadata to the
// agent // agent
err = comms.SendClientInfo(agentConn, comms.ClientInfo{ err = comms.SendClientInfo(agentConn, comms.ClientInfo{
ClientId: client.clientId, ClientId: client.ClientId,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -168,18 +166,18 @@ func (admin *Admin) RemoveAgent(publicId string) error {
return nil return nil
} }
func (admin *Admin) RemoveClient(client *Client) error { func (admin *Admin) RemoveClient(client *ClientConnection) error {
admin.mutex.Lock() admin.mutex.Lock()
defer admin.mutex.Unlock() defer admin.mutex.Unlock()
log.Printf("Removing client: '%s' created at %s\n", client.publicId, log.Printf("Removing client: '%d' created at %s\n", client.ClientId,
client.startTime.Format("2006-01-02 15:04:05")) client.StartTime.Format(time.DateTime))
// try to explicitly close connection to the agent. // try to explicitly close connection to the agent.
_ = client.agent.Close() _ = client.agent.Close()
_ = client.client.Close() _ = client.client.Close()
for i, _client := range admin.clients { for i, _client := range admin.clients {
if _client == _client { if _client.ClientId == client.ClientId {
admin.clients = append(admin.clients[:i], admin.clients[i+1:]...) admin.clients = append(admin.clients[:i], admin.clients[i+1:]...)
break break
} }
@ -212,26 +210,27 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
go func() { go func() {
comms.ListenForAgentEvents(agent.commChannel.SideChannel, comms.ListenForAgentEvents(agent.commChannel.SideChannel,
func(info comms.AgentInfo) { func(info comms.AgentInfo) {
agent.agentInfo = info agent.AgentInfo = info
admin.logStatus() admin.logStatus()
}, },
func(session comms.SessionInfo) { func(session comms.SessionInfo) {
log.Println("Recceived sessioninfo ", session)
for _, client := range admin.clients { for _, client := range admin.clients {
// a bit hacky. There should be at most one client that has an unset session // a bit hacky. There should be at most one client that has an unset session
// Very unlikely for multiple sessions to start at the same point in time. // Very unlikely for multiple sessions to start at the same point in time.
if client.publicId == agent.publicId && client.sessionType != session.SessionType { if strconv.Itoa(client.ClientId) == session.ClientId {
client.sessionType = session.SessionType client.SessionType = session.SessionType
break break
} }
} }
}, },
func(expiry comms.ExpiryTimeUpdate) { func(expiry comms.ExpiryTimeUpdate) {
agent.expiryTime = expiry.ExpiryTime agent.ExpiryTime = expiry.ExpiryTime
admin.logStatus() admin.logStatus()
}) })
}() }()
go log.Printf("Agent registered: '%s'\n", publicId) go log.Printf("AgentConnection registered: '%s'\n", publicId)
for !agent.commChannel.Session.IsClosed() { for !agent.commChannel.Session.IsClosed() {
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
} }

14
pkg/models/agent.go Normal file
View File

@ -0,0 +1,14 @@
package models
import (
"converge/pkg/comms"
"time"
)
type Agent struct {
PublicId string
StartTime time.Time
AgentInfo comms.AgentInfo
ExpiryTime time.Time
}

12
pkg/models/client.go Normal file
View File

@ -0,0 +1,12 @@
package models
import (
"time"
)
type Client struct {
PublicId string
ClientId int
StartTime time.Time
SessionType string
}