From 788050df32ff3a96c1c25cf75725a73eb26967e0 Mon Sep 17 00:00:00 2001
From: Erik Brakkee <erik@brakkee.org>
Date: Sat, 27 Jul 2024 22:37:40 +0200
Subject: [PATCH] Unique ids for clients generated by converge server and made
 available to the ssh session through a net.Conn extension that passes the ID
 to the SSH session through the LocalAddr().

---
 cmd/agent/agent.go               |  5 ++---
 pkg/agent/session.go             | 17 ++++++++---------
 pkg/comms/agentlistener.go       | 32 +++++++++++++++++++++++++-------
 pkg/comms/agentserver.go         | 19 +++++++++++++++----
 pkg/comms/events.go              |  2 +-
 pkg/concurrency/atomiccounter.go | 22 ++++++++++++++++++++++
 pkg/converge/admin.go            | 25 ++++++++++++++++++++++---
 7 files changed, 95 insertions(+), 27 deletions(-)
 create mode 100644 pkg/concurrency/atomiccounter.go

diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go
index 53b354a..4887672 100755
--- a/cmd/agent/agent.go
+++ b/cmd/agent/agent.go
@@ -32,7 +32,7 @@ import (
 var hostPrivateKey []byte
 
 func SftpHandler(sess ssh.Session) {
-	uid := int(time.Now().UnixMicro())
+	uid := sess.LocalAddr().String()
 	agent.Login(uid, sess)
 	defer agent.LogOut(uid)
 
@@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string,
 		if err != nil {
 			panic(err)
 		}
-		uid := int(time.Now().UnixMilli())
+		uid := s.LocalAddr().String()
 		agent.Login(uid, s)
 		iowrappers.SynchronizeStreams(process.Pipe(), s)
 		agent.LogOut(uid)
 		process.Wait()
-		process.Wait()
 	})
 
 	log.Println("starting ssh server, waiting for debug sessions")
diff --git a/pkg/agent/session.go b/pkg/agent/session.go
index 2a6cdcf..b39db51 100644
--- a/pkg/agent/session.go
+++ b/pkg/agent/session.go
@@ -12,7 +12,6 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"text/template"
 	"time"
@@ -49,7 +48,7 @@ type AgentState struct {
 	ticker         *time.Ticker
 
 	// map of unique session id to a session
-	sessions map[int]*AgentSession
+	sessions map[string]*AgentSession
 
 	lastUserLoginTime time.Time
 	agentUsed         bool
@@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
 		lastExpiryTimmeReported: time.Time{},
 		tickerInterval:          tickerInterval,
 		ticker:                  time.NewTicker(tickerInterval),
-		sessions:                make(map[int]*AgentSession),
+		sessions:                make(map[string]*AgentSession),
 
 		lastUserLoginTime: time.Time{},
 		agentUsed:         false,
@@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel,
 
 }
 
-func Login(sessionId int, sshSession ssh.Session) {
+func Login(sessionId string, sshSession ssh.Session) {
 	events <- async.Async(login, sessionId, sshSession)
 }
 
-func LogOut(sessionId int) {
+func LogOut(sessionId string) {
 	events <- async.Async(logOut, sessionId)
 }
 
@@ -176,7 +175,7 @@ func holdFileMessage() string {
 	return message
 }
 
-func login(sessionId int, sshSession ssh.Session) {
+func login(sessionId string, sshSession ssh.Session) {
 	log.Println("New login")
 	hostname, _ := os.Hostname()
 
@@ -238,7 +237,7 @@ func formatHelpMessage() string {
 	return helpFormatted
 }
 
-func logOut(sessionId int) {
+func logOut(sessionId string) {
 	log.Println("User logged out")
 	delete(state.sessions, sessionId)
 	logStatus()
@@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) {
 func logStatus() {
 	fmt := "%-20s %-20s %-20s"
 	log.Println()
-	log.Printf(fmt, "UID", "START_TIME", "TYPE")
+	log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
 	for uid, session := range state.sessions {
 		sessionType := session.sshSession.Subsystem()
 		if sessionType == "" {
 			sessionType = "ssh"
 		}
-		log.Printf(fmt, strconv.Itoa(uid),
+		log.Printf(fmt, uid,
 			session.startTime.Format(time.DateTime),
 			sessionType)
 	}
diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go
index 9c4928d..45ce80e 100644
--- a/pkg/comms/agentlistener.go
+++ b/pkg/comms/agentlistener.go
@@ -1,7 +1,9 @@
 package comms
 
 import (
+	"converge/pkg/websocketutil"
 	"net"
+	"strconv"
 )
 
 type AgentListener struct {
@@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener {
 	return AgentListener{decorated: listener}
 }
 
+type LocalAddrHackConn struct {
+	net.Conn
+	localAddr net.Addr
+}
+
+func (conn LocalAddrHackConn) LocalAddr() net.Addr {
+	return conn.localAddr
+}
+
+func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
+	addr := LocalAddrHackConn{
+		localAddr: websocketutil.WebSocketAddr(clientId),
+	}
+	addr.Conn = conn
+	return addr
+}
+
 func (listener AgentListener) Accept() (net.Conn, error) {
 	conn, err := listener.decorated.Accept()
 	if err != nil {
 		return nil, err
 	}
 
-	//_, err = CheckProtocolVersion(Agent, conn)
-	//if err != nil {
-	//	conn.Close()
-	//	return nil, err
-	//}
-
-	return conn, nil
+	clientInfo, err := ReceiveClientInfo(conn)
+	if err != nil {
+		conn.Close()
+		return nil, err
+	}
+	return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil
 }
 
 func (listener AgentListener) Close() error {
diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go
index 593b447..a73e164 100644
--- a/pkg/comms/agentserver.go
+++ b/pkg/comms/agentserver.go
@@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel,
 		err := channel.Decoder.Decode(&result)
 
 		if err != nil {
-			// TODO more clean solution, need to explicitly close when agent exits.
 			log.Printf("Exiting agent listener %v", err)
 			return
 		}
@@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) {
 		err := channel.SideChannel.Decoder.Decode(&result)
 
 		if err != nil {
-			// TODO more clean solution, need to explicitly close when agent exits.
 			log.Printf("Exiting agent listener %v", err)
 			return
 		}
@@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e
 		return ServerInfo{}, nil
 	}
 	// TODO remove logging
-	log.Println("Server info received: ", serverInfo)
+	log.Println("Agent configuration received from server")
 
 	return serverInfo, err
 }
@@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo,
 // is terminated.
 
 func CheckProtocolVersion(role Role, channel GOBChannel) error {
-	log.Println("ROLE ", role)
 	switch role {
 	case Agent:
 		err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
@@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
 // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
 // decorates the yamux Session (which is a listener) and uses this connection to exchange some
 // metadata before the connection is handed back to SSH.
+
+func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
+	channel := NewGOBChannel(conn)
+	return SendWithTimeout(channel, info)
+}
+
+func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
+	channel := NewGOBChannel(conn)
+	clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
+	if err != nil {
+		return ClientInfo{}, err
+	}
+	return clientInfo, nil
+}
diff --git a/pkg/comms/events.go b/pkg/comms/events.go
index dfcc222..83f0c9d 100644
--- a/pkg/comms/events.go
+++ b/pkg/comms/events.go
@@ -24,7 +24,7 @@ type AgentInfo struct {
 }
 
 type ClientInfo struct {
-	ClientId string
+	ClientId int
 }
 
 type SessionInfo struct {
diff --git a/pkg/concurrency/atomiccounter.go b/pkg/concurrency/atomiccounter.go
new file mode 100644
index 0000000..1599a04
--- /dev/null
+++ b/pkg/concurrency/atomiccounter.go
@@ -0,0 +1,22 @@
+package concurrency
+
+import "sync"
+
+type AtomicCounter struct {
+	mutex     sync.Mutex
+	lastValue int
+}
+
+func NewAtomicCounter() *AtomicCounter {
+	return &AtomicCounter{
+		mutex:     sync.Mutex{},
+		lastValue: 0,
+	}
+}
+
+func (counter *AtomicCounter) IncrementAndGet() int {
+	counter.mutex.Lock()
+	defer counter.mutex.Unlock()
+	counter.lastValue++
+	return counter.lastValue
+}
diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go
index dce80b6..ca8ab55 100644
--- a/pkg/converge/admin.go
+++ b/pkg/converge/admin.go
@@ -2,11 +2,13 @@ package converge
 
 import (
 	"converge/pkg/comms"
+	"converge/pkg/concurrency"
 	"converge/pkg/iowrappers"
 	"fmt"
 	"io"
 	"log"
 	"net"
+	"strconv"
 	"sync"
 	"time"
 )
@@ -21,8 +23,11 @@ type Agent struct {
 	expiryTime time.Time
 }
 
+var clientIdGenerator = concurrency.NewAtomicCounter()
+
 type Client struct {
 	publicId    string
+	clientId    int
 	agent       net.Conn
 	client      iowrappers.ReadWriteAddrCloser
 	startTime   time.Time
@@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
 	agentConn net.Conn) *Client {
 	return &Client{
 		publicId:  publicId,
+		clientId:  clientIdGenerator.IncrementAndGet(),
 		agent:     agentConn,
 		client:    clientConn,
 		startTime: time.Now(),
@@ -78,10 +84,12 @@ func (admin *Admin) logStatus() {
 			agent.agentInfo.OS)
 	}
 	log.Println("")
-	fmt = "%-20s %-20s %-20s %-20s\n"
-	log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
+	fmt = "%-10s %-20s %-20s %-20s %-20s\n"
+	log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
 	for _, client := range admin.clients {
-		log.Printf(fmt, client.publicId,
+		log.Printf(fmt,
+			strconv.Itoa(client.clientId),
+			client.publicId,
 			client.startTime.Format(time.DateTime),
 			client.client.RemoteAddr(),
 			client.sessionType)
@@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
 		return nil, err
 	}
 	log.Println("Successful websocket connection to agent")
+	log.Println("Sending connection information to agent")
 
 	client := NewClient(publicId, clientConn, agentConn)
+
+	// Before using this connection for SSH we use it to send client metadata to the
+	// agent
+	err = comms.SendClientInfo(agentConn, comms.ClientInfo{
+		ClientId: client.clientId,
+	})
+	if err != nil {
+		return nil, err
+	}
+
 	admin.clients = append(admin.clients, client)
 	admin.logStatus()
 	return client, nil