From 7a51e3ac454748282758a2eafbde43a31082e9e1 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Sat, 27 Jul 2024 22:37:40 +0200 Subject: [PATCH] Unique ids for clients generated by converge server and made available to the ssh session through a net.Conn extension that passes the ID to the SSH session through the LocalAddr(). --- cmd/agent/agent.go | 5 ++--- pkg/agent/session.go | 17 ++++++++--------- pkg/comms/agentlistener.go | 32 +++++++++++++++++++++++++------- pkg/comms/agentserver.go | 19 +++++++++++++++---- pkg/comms/events.go | 2 +- pkg/concurrency/atomiccounter.go | 22 ++++++++++++++++++++++ pkg/converge/admin.go | 25 ++++++++++++++++++++++--- 7 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 pkg/concurrency/atomiccounter.go diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 53b354a..4887672 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -32,7 +32,7 @@ import ( var hostPrivateKey []byte func SftpHandler(sess ssh.Session) { - uid := int(time.Now().UnixMicro()) + uid := sess.LocalAddr().String() agent.Login(uid, sess) defer agent.LogOut(uid) @@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string, if err != nil { panic(err) } - uid := int(time.Now().UnixMilli()) + uid := s.LocalAddr().String() agent.Login(uid, s) iowrappers.SynchronizeStreams(process.Pipe(), s) agent.LogOut(uid) process.Wait() - process.Wait() }) log.Println("starting ssh server, waiting for debug sessions") diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 2a6cdcf..b39db51 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -12,7 +12,6 @@ import ( "os" "path/filepath" "runtime" - "strconv" "strings" "text/template" "time" @@ -49,7 +48,7 @@ type AgentState struct { ticker *time.Ticker // map of unique session id to a session - sessions map[int]*AgentSession + sessions map[string]*AgentSession lastUserLoginTime time.Time agentUsed bool @@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel, lastExpiryTimmeReported: time.Time{}, tickerInterval: tickerInterval, ticker: time.NewTicker(tickerInterval), - sessions: make(map[int]*AgentSession), + sessions: make(map[string]*AgentSession), lastUserLoginTime: time.Time{}, agentUsed: false, @@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel, } -func Login(sessionId int, sshSession ssh.Session) { +func Login(sessionId string, sshSession ssh.Session) { events <- async.Async(login, sessionId, sshSession) } -func LogOut(sessionId int) { +func LogOut(sessionId string) { events <- async.Async(logOut, sessionId) } @@ -176,7 +175,7 @@ func holdFileMessage() string { return message } -func login(sessionId int, sshSession ssh.Session) { +func login(sessionId string, sshSession ssh.Session) { log.Println("New login") hostname, _ := os.Hostname() @@ -238,7 +237,7 @@ func formatHelpMessage() string { return helpFormatted } -func logOut(sessionId int) { +func logOut(sessionId string) { log.Println("User logged out") delete(state.sessions, sessionId) logStatus() @@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) { func logStatus() { fmt := "%-20s %-20s %-20s" log.Println() - log.Printf(fmt, "UID", "START_TIME", "TYPE") + log.Printf(fmt, "CLIENT", "START_TIME", "TYPE") for uid, session := range state.sessions { sessionType := session.sshSession.Subsystem() if sessionType == "" { sessionType = "ssh" } - log.Printf(fmt, strconv.Itoa(uid), + log.Printf(fmt, uid, session.startTime.Format(time.DateTime), sessionType) } diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go index 9c4928d..45ce80e 100644 --- a/pkg/comms/agentlistener.go +++ b/pkg/comms/agentlistener.go @@ -1,7 +1,9 @@ package comms import ( + "converge/pkg/websocketutil" "net" + "strconv" ) type AgentListener struct { @@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener { return AgentListener{decorated: listener} } +type LocalAddrHackConn struct { + net.Conn + localAddr net.Addr +} + +func (conn LocalAddrHackConn) LocalAddr() net.Addr { + return conn.localAddr +} + +func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn { + addr := LocalAddrHackConn{ + localAddr: websocketutil.WebSocketAddr(clientId), + } + addr.Conn = conn + return addr +} + func (listener AgentListener) Accept() (net.Conn, error) { conn, err := listener.decorated.Accept() if err != nil { return nil, err } - //_, err = CheckProtocolVersion(Agent, conn) - //if err != nil { - // conn.Close() - // return nil, err - //} - - return conn, nil + clientInfo, err := ReceiveClientInfo(conn) + if err != nil { + conn.Close() + return nil, err + } + return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil } func (listener AgentListener) Close() error { diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 593b447..a73e164 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel, err := channel.Decoder.Decode(&result) if err != nil { - // TODO more clean solution, need to explicitly close when agent exits. log.Printf("Exiting agent listener %v", err) return } @@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) { err := channel.SideChannel.Decoder.Decode(&result) if err != nil { - // TODO more clean solution, need to explicitly close when agent exits. log.Printf("Exiting agent listener %v", err) return } @@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e return ServerInfo{}, nil } // TODO remove logging - log.Println("Server info received: ", serverInfo) + log.Println("Agent configuration received from server") return serverInfo, err } @@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, // is terminated. func CheckProtocolVersion(role Role, channel GOBChannel) error { - log.Println("ROLE ", role) switch role { case Agent: err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) @@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error { // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that // decorates the yamux Session (which is a listener) and uses this connection to exchange some // metadata before the connection is handed back to SSH. + +func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { + channel := NewGOBChannel(conn) + return SendWithTimeout(channel, info) +} + +func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { + channel := NewGOBChannel(conn) + clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) + if err != nil { + return ClientInfo{}, err + } + return clientInfo, nil +} diff --git a/pkg/comms/events.go b/pkg/comms/events.go index dfcc222..83f0c9d 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -24,7 +24,7 @@ type AgentInfo struct { } type ClientInfo struct { - ClientId string + ClientId int } type SessionInfo struct { diff --git a/pkg/concurrency/atomiccounter.go b/pkg/concurrency/atomiccounter.go new file mode 100644 index 0000000..1599a04 --- /dev/null +++ b/pkg/concurrency/atomiccounter.go @@ -0,0 +1,22 @@ +package concurrency + +import "sync" + +type AtomicCounter struct { + mutex sync.Mutex + lastValue int +} + +func NewAtomicCounter() *AtomicCounter { + return &AtomicCounter{ + mutex: sync.Mutex{}, + lastValue: 0, + } +} + +func (counter *AtomicCounter) IncrementAndGet() int { + counter.mutex.Lock() + defer counter.mutex.Unlock() + counter.lastValue++ + return counter.lastValue +} diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index dce80b6..ca8ab55 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -2,11 +2,13 @@ package converge import ( "converge/pkg/comms" + "converge/pkg/concurrency" "converge/pkg/iowrappers" "fmt" "io" "log" "net" + "strconv" "sync" "time" ) @@ -21,8 +23,11 @@ type Agent struct { expiryTime time.Time } +var clientIdGenerator = concurrency.NewAtomicCounter() + type Client struct { publicId string + clientId int agent net.Conn client iowrappers.ReadWriteAddrCloser startTime time.Time @@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser, agentConn net.Conn) *Client { return &Client{ publicId: publicId, + clientId: clientIdGenerator.IncrementAndGet(), agent: agentConn, client: clientConn, startTime: time.Now(), @@ -78,10 +84,12 @@ func (admin *Admin) logStatus() { agent.agentInfo.OS) } log.Println("") - fmt = "%-20s %-20s %-20s %-20s\n" - log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE") + fmt = "%-10s %-20s %-20s %-20s %-20s\n" + log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE") for _, client := range admin.clients { - log.Printf(fmt, client.publicId, + log.Printf(fmt, + strconv.Itoa(client.clientId), + client.publicId, client.startTime.Format(time.DateTime), client.client.RemoteAddr(), client.sessionType) @@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd return nil, err } log.Println("Successful websocket connection to agent") + log.Println("Sending connection information to agent") client := NewClient(publicId, clientConn, agentConn) + + // Before using this connection for SSH we use it to send client metadata to the + // agent + err = comms.SendClientInfo(agentConn, comms.ClientInfo{ + ClientId: client.clientId, + }) + if err != nil { + return nil, err + } + admin.clients = append(admin.clients, client) admin.logStatus() return client, nil