refactoring towards being able to send events from Admin to UI (websocket) without exposing connection info but only metadata.

This commit is contained in:
Erik Brakkee 2024-07-30 19:03:21 +02:00
parent 7783ab51a8
commit 816e8d8609
7 changed files with 118 additions and 120 deletions

View File

@ -33,9 +33,12 @@ import (
var hostPrivateKey []byte
func SftpHandler(sess ssh.Session) {
uid := sess.LocalAddr().String()
agent.Login(uid, sess)
defer agent.LogOut(uid)
sessionInfo := comms.NewSessionInfo(
sess.LocalAddr().String(),
"sftp",
)
agent.Login(sessionInfo, sess)
defer agent.LogOut(sessionInfo.ClientId)
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
@ -67,10 +70,12 @@ func sshServer(hostKeyFile string, shellCommand string,
if err != nil {
panic(err)
}
uid := s.LocalAddr().String()
agent.Login(uid, s)
sessionInfo := comms.NewSessionInfo(
s.LocalAddr().String(), "ssh",
)
agent.Login(sessionInfo, s)
iowrappers.SynchronizeStreams(process.Pipe(), s)
agent.LogOut(uid)
agent.LogOut(sessionInfo.ClientId)
// will cause addition goroutines to remmain alive when the SSH
// session is killed. For now acceptable since the agent is a short-lived
// process. Using Kill() here will create defunct processes and in normal

View File

@ -47,7 +47,7 @@ type AgentState struct {
ticker *time.Ticker
// map of unique session id to a session
sessions map[string]*AgentSession
clients map[string]*AgentSession
lastUserLoginTime time.Time
agentUsed bool
@ -90,7 +90,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval),
sessions: make(map[string]*AgentSession),
clients: make(map[string]*AgentSession),
lastUserLoginTime: time.Time{},
agentUsed: false,
@ -120,12 +120,12 @@ func ConfigureAgent(commChannel comms.CommChannel,
}
func Login(sessionId string, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession)
func Login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
events <- async.Async(login, sessionInfo, sshSession)
}
func LogOut(sessionId string) {
events <- async.Async(logOut, sessionId)
func LogOut(clientId string) {
events <- async.Async(logOut, clientId)
}
// Internal interface synchronous
@ -174,19 +174,10 @@ func holdFileMessage() string {
return message
}
func login(sessionId string, sshSession ssh.Session) {
func login(sessionInfo comms.SessionInfo, sshSession ssh.Session) {
log.Println("New login")
hostname, _ := os.Hostname()
sessionType := sshSession.Subsystem()
if sessionType == "" {
sessionType = "ssh"
}
comms.Send(state.commChannel.SideChannel,
comms.ConvergeMessage{
Value: comms.NewSessionInfo(sessionType),
})
holdFileStats, ok := fileExistsWithStats(holdFilename)
if ok {
if holdFileStats.ModTime().After(time.Now()) {
@ -205,9 +196,16 @@ func login(sessionId string, sshSession ssh.Session) {
startTime: time.Now(),
sshSession: sshSession,
}
state.sessions[sessionId] = &agentSession
state.clients[sessionInfo.ClientId] = &agentSession
state.lastUserLoginTime = time.Now()
state.agentUsed = true
err := comms.SendWithTimeout(state.commChannel.SideChannel,
comms.ConvergeMessage{Value: sessionInfo})
if err != nil {
log.Printf("Could not send session info to converge server, information on server may be incomplete %v", err)
}
logStatus()
printMessage(sshSession, fmt.Sprintf("You are now on %s\n", hostname))
@ -236,9 +234,9 @@ func formatHelpMessage() string {
return helpFormatted
}
func logOut(sessionId string) {
func logOut(clientId string) {
log.Println("User logged out")
delete(state.sessions, sessionId)
delete(state.clients, clientId)
logStatus()
check()
}
@ -254,7 +252,7 @@ func logStatus() {
fmt := "%-20s %-20s %-20s"
log.Println()
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
for uid, session := range state.sessions {
for uid, session := range state.clients {
sessionType := session.sshSession.Subsystem()
if sessionType == "" {
sessionType = "ssh"
@ -308,7 +306,7 @@ func holdFileChange() {
}
// Behavior to implement
// 1. there is a global timeout for all agent sessions together: state.agentExpirtyTime
// 1. there is a global timeout for all agent clients together: state.agentExpirtyTime
// 2. The expiry time is relative to the modification time of the .hold file in the
// agent directory or, if that file does not exist, the start time of the agent.
// 3. if we are close to the expiry time then we message users with instruction on
@ -331,12 +329,12 @@ func check() {
if expiryTime.Sub(now) < state.advanceWarningTime {
messageUsers(
fmt.Sprintf("Session will expire at %s", expiryTime.Format(time.DateTime)))
for _, session := range state.sessions {
for _, session := range state.clients {
printHelpMessage(session.sshSession)
}
}
if state.agentUsed && !fileExists(holdFilename) && len(state.sessions) == 0 {
if state.agentUsed && !fileExists(holdFilename) && len(state.clients) == 0 {
log.Printf("All clients disconnected and no '%s' file found, exiting", holdFilename)
os.Exit(0)
}
@ -344,7 +342,7 @@ func check() {
func messageUsers(message string) {
log.Printf("=== Notification to users: %s", message)
for _, session := range state.sessions {
for _, session := range state.clients {
printMessage(session.sshSession, message)
}
}

View File

@ -218,41 +218,6 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error {
}
}
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
channel := NewGOBChannel(conn)
sends := make(chan bool, 10)
receives := make(chan ProtocolVersion, 10)
errors := make(chan error, 10)
SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
ReceiveAsync(channel, receives, errors)
select {
case <-time.After(MESSAGE_TIMEOUT):
log.Println("PROTOCOLVERSION: timeout")
return fmt.Errorf("Timeout waiting for protocol version")
case err := <-errors:
log.Printf("PROTOCOLVERSION: %v", err)
return err
case protocolVersion := <-receives:
otherVersion := protocolVersion.Version
if PROTOCOL_VERSION != otherVersion {
switch role {
case Agent:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
PROTOCOL_VERSION, otherVersion)
case ConvergeServer:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
otherVersion, PROTOCOL_VERSION)
}
return fmt.Errorf("Protocol version mismatch")
}
log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
return nil
}
}
// Session info metadata exchange. These are sent over the SSH connection. The agent embedded
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
// decorates the yamux Session (which is a listener) and uses this connection to exchange some

View File

@ -28,6 +28,8 @@ type ClientInfo struct {
}
type SessionInfo struct {
ClientId string
// "ssh", "sftp"
SessionType string
}
@ -73,8 +75,11 @@ func NewAgentInfo() AgentInfo {
}
}
func NewSessionInfo(sessionType string) SessionInfo {
return SessionInfo{SessionType: sessionType}
func NewSessionInfo(clientId, sessionType string) SessionInfo {
return SessionInfo{
ClientId: clientId,
SessionType: sessionType,
}
}
func NewExpiryTimeUpdate(expiryTime time.Time) ExpiryTimeUpdate {

View File

@ -4,6 +4,7 @@ import (
"converge/pkg/comms"
"converge/pkg/concurrency"
"converge/pkg/iowrappers"
"converge/pkg/models"
"fmt"
"io"
"log"
@ -13,59 +14,56 @@ import (
"time"
)
type Agent struct {
type AgentConnection struct {
models.Agent
// server session
commChannel comms.CommChannel
publicId string
startTime time.Time
agentInfo comms.AgentInfo
expiryTime time.Time
}
var clientIdGenerator = concurrency.NewAtomicCounter()
type Client struct {
publicId string
clientId int
type ClientConnection struct {
models.Client
agent net.Conn
client iowrappers.ReadWriteAddrCloser
startTime time.Time
sessionType string
}
func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent {
return &Agent{
func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *AgentConnection {
return &AgentConnection{
Agent: models.Agent{
PublicId: publicId,
StartTime: time.Now(),
AgentInfo: agentInfo,
},
commChannel: commChannel,
publicId: publicId,
startTime: time.Now(),
agentInfo: agentInfo,
}
}
func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
agentConn net.Conn) *Client {
return &Client{
publicId: publicId,
clientId: clientIdGenerator.IncrementAndGet(),
agentConn net.Conn) *ClientConnection {
return &ClientConnection{
Client: models.Client{
PublicId: publicId,
ClientId: clientIdGenerator.IncrementAndGet(),
StartTime: time.Now(),
},
agent: agentConn,
client: clientConn,
startTime: time.Now(),
}
}
type Admin struct {
// map of public id to agent
mutex sync.Mutex
agents map[string]*Agent
clients []*Client
agents map[string]*AgentConnection
clients []*ClientConnection
}
func NewAdmin() *Admin {
admin := Admin{
mutex: sync.Mutex{},
agents: make(map[string]*Agent),
clients: make([]*Client, 0), // not strictly needed
agents: make(map[string]*AgentConnection),
clients: make([]*ClientConnection, 0), // not strictly needed
}
return &admin
}
@ -76,34 +74,34 @@ func (admin *Admin) logStatus() {
"USER", "HOST", "OS")
for _, agent := range admin.agents {
agent.commChannel.Session.RemoteAddr()
log.Printf(fmt, agent.publicId,
agent.startTime.Format(time.DateTime),
agent.expiryTime.Format(time.DateTime),
agent.agentInfo.Username,
agent.agentInfo.Hostname,
agent.agentInfo.OS)
log.Printf(fmt, agent.PublicId,
agent.StartTime.Format(time.DateTime),
agent.ExpiryTime.Format(time.DateTime),
agent.AgentInfo.Username,
agent.AgentInfo.Hostname,
agent.AgentInfo.OS)
}
log.Println("")
fmt = "%-10s %-20s %-20s %-20s %-20s\n"
log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
for _, client := range admin.clients {
log.Printf(fmt,
strconv.Itoa(client.clientId),
client.publicId,
client.startTime.Format(time.DateTime),
strconv.Itoa(client.ClientId),
client.PublicId,
client.StartTime.Format(time.DateTime),
client.client.RemoteAddr(),
client.sessionType)
client.SessionType)
}
log.Printf("\n")
}
func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) {
func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*AgentConnection, error) {
admin.mutex.Lock()
defer admin.mutex.Unlock()
agent := admin.agents[publicId]
if agent != nil {
return nil, fmt.Errorf("A different agent with same publicId '%s' already registered", publicId)
return nil, fmt.Errorf("A different agent with same PublicId '%s' already registered", publicId)
}
commChannel, err := comms.NewCommChannel(comms.ConvergeServer, conn)
@ -117,14 +115,14 @@ func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io
return agent, nil
}
func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*Client, error) {
func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser) (*ClientConnection, error) {
admin.mutex.Lock()
defer admin.mutex.Unlock()
agent := admin.agents[publicId]
if agent == nil {
// we should setup on-demend connections ot agents later.
return nil, fmt.Errorf("No agent found for publicId '%s'", publicId)
return nil, fmt.Errorf("No agent found for PublicId '%s'", publicId)
}
agentConn, err := agent.commChannel.Session.Open()
@ -139,7 +137,7 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
// Before using this connection for SSH we use it to send client metadata to the
// agent
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
ClientId: client.clientId,
ClientId: client.ClientId,
})
if err != nil {
return nil, err
@ -168,18 +166,18 @@ func (admin *Admin) RemoveAgent(publicId string) error {
return nil
}
func (admin *Admin) RemoveClient(client *Client) error {
func (admin *Admin) RemoveClient(client *ClientConnection) error {
admin.mutex.Lock()
defer admin.mutex.Unlock()
log.Printf("Removing client: '%s' created at %s\n", client.publicId,
client.startTime.Format("2006-01-02 15:04:05"))
log.Printf("Removing client: '%d' created at %s\n", client.ClientId,
client.StartTime.Format(time.DateTime))
// try to explicitly close connection to the agent.
_ = client.agent.Close()
_ = client.client.Close()
for i, _client := range admin.clients {
if _client == _client {
if _client.ClientId == client.ClientId {
admin.clients = append(admin.clients[:i], admin.clients[i+1:]...)
break
}
@ -212,26 +210,27 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
go func() {
comms.ListenForAgentEvents(agent.commChannel.SideChannel,
func(info comms.AgentInfo) {
agent.agentInfo = info
agent.AgentInfo = info
admin.logStatus()
},
func(session comms.SessionInfo) {
log.Println("Recceived sessioninfo ", session)
for _, client := range admin.clients {
// a bit hacky. There should be at most one client that has an unset session
// Very unlikely for multiple sessions to start at the same point in time.
if client.publicId == agent.publicId && client.sessionType != session.SessionType {
client.sessionType = session.SessionType
if strconv.Itoa(client.ClientId) == session.ClientId {
client.SessionType = session.SessionType
break
}
}
},
func(expiry comms.ExpiryTimeUpdate) {
agent.expiryTime = expiry.ExpiryTime
agent.ExpiryTime = expiry.ExpiryTime
admin.logStatus()
})
}()
go log.Printf("Agent registered: '%s'\n", publicId)
go log.Printf("AgentConnection registered: '%s'\n", publicId)
for !agent.commChannel.Session.IsClosed() {
time.Sleep(250 * time.Millisecond)
}

14
pkg/models/agent.go Normal file
View File

@ -0,0 +1,14 @@
package models
import (
"converge/pkg/comms"
"time"
)
type Agent struct {
PublicId string
StartTime time.Time
AgentInfo comms.AgentInfo
ExpiryTime time.Time
}

12
pkg/models/client.go Normal file
View File

@ -0,0 +1,12 @@
package models
import (
"time"
)
type Client struct {
PublicId string
ClientId int
StartTime time.Time
SessionType string
}