Unique ids for clients generated by converge server and made available to the ssh session through a net.Conn extension that passes the ID to the SSH session through the LocalAddr().

This commit is contained in:
Erik Brakkee 2024-07-27 22:37:40 +02:00
parent 9d0675b2f2
commit 788050df32
7 changed files with 95 additions and 27 deletions

View File

@ -32,7 +32,7 @@ import (
var hostPrivateKey []byte var hostPrivateKey []byte
func SftpHandler(sess ssh.Session) { func SftpHandler(sess ssh.Session) {
uid := int(time.Now().UnixMicro()) uid := sess.LocalAddr().String()
agent.Login(uid, sess) agent.Login(uid, sess)
defer agent.LogOut(uid) defer agent.LogOut(uid)
@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string,
if err != nil { if err != nil {
panic(err) panic(err)
} }
uid := int(time.Now().UnixMilli()) uid := s.LocalAddr().String()
agent.Login(uid, s) agent.Login(uid, s)
iowrappers.SynchronizeStreams(process.Pipe(), s) iowrappers.SynchronizeStreams(process.Pipe(), s)
agent.LogOut(uid) agent.LogOut(uid)
process.Wait() process.Wait()
process.Wait()
}) })
log.Println("starting ssh server, waiting for debug sessions") log.Println("starting ssh server, waiting for debug sessions")

View File

@ -12,7 +12,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"text/template" "text/template"
"time" "time"
@ -49,7 +48,7 @@ type AgentState struct {
ticker *time.Ticker ticker *time.Ticker
// map of unique session id to a session // map of unique session id to a session
sessions map[int]*AgentSession sessions map[string]*AgentSession
lastUserLoginTime time.Time lastUserLoginTime time.Time
agentUsed bool agentUsed bool
@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
lastExpiryTimmeReported: time.Time{}, lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval, tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval), ticker: time.NewTicker(tickerInterval),
sessions: make(map[int]*AgentSession), sessions: make(map[string]*AgentSession),
lastUserLoginTime: time.Time{}, lastUserLoginTime: time.Time{},
agentUsed: false, agentUsed: false,
@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel,
} }
func Login(sessionId int, sshSession ssh.Session) { func Login(sessionId string, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession) events <- async.Async(login, sessionId, sshSession)
} }
func LogOut(sessionId int) { func LogOut(sessionId string) {
events <- async.Async(logOut, sessionId) events <- async.Async(logOut, sessionId)
} }
@ -176,7 +175,7 @@ func holdFileMessage() string {
return message return message
} }
func login(sessionId int, sshSession ssh.Session) { func login(sessionId string, sshSession ssh.Session) {
log.Println("New login") log.Println("New login")
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
@ -238,7 +237,7 @@ func formatHelpMessage() string {
return helpFormatted return helpFormatted
} }
func logOut(sessionId int) { func logOut(sessionId string) {
log.Println("User logged out") log.Println("User logged out")
delete(state.sessions, sessionId) delete(state.sessions, sessionId)
logStatus() logStatus()
@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) {
func logStatus() { func logStatus() {
fmt := "%-20s %-20s %-20s" fmt := "%-20s %-20s %-20s"
log.Println() log.Println()
log.Printf(fmt, "UID", "START_TIME", "TYPE") log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
for uid, session := range state.sessions { for uid, session := range state.sessions {
sessionType := session.sshSession.Subsystem() sessionType := session.sshSession.Subsystem()
if sessionType == "" { if sessionType == "" {
sessionType = "ssh" sessionType = "ssh"
} }
log.Printf(fmt, strconv.Itoa(uid), log.Printf(fmt, uid,
session.startTime.Format(time.DateTime), session.startTime.Format(time.DateTime),
sessionType) sessionType)
} }

View File

@ -1,7 +1,9 @@
package comms package comms
import ( import (
"converge/pkg/websocketutil"
"net" "net"
"strconv"
) )
type AgentListener struct { type AgentListener struct {
@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener {
return AgentListener{decorated: listener} return AgentListener{decorated: listener}
} }
type LocalAddrHackConn struct {
net.Conn
localAddr net.Addr
}
func (conn LocalAddrHackConn) LocalAddr() net.Addr {
return conn.localAddr
}
func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
addr := LocalAddrHackConn{
localAddr: websocketutil.WebSocketAddr(clientId),
}
addr.Conn = conn
return addr
}
func (listener AgentListener) Accept() (net.Conn, error) { func (listener AgentListener) Accept() (net.Conn, error) {
conn, err := listener.decorated.Accept() conn, err := listener.decorated.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
//_, err = CheckProtocolVersion(Agent, conn) clientInfo, err := ReceiveClientInfo(conn)
//if err != nil { if err != nil {
// conn.Close() conn.Close()
// return nil, err return nil, err
//} }
return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil
return conn, nil
} }
func (listener AgentListener) Close() error { func (listener AgentListener) Close() error {

View File

@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel,
err := channel.Decoder.Decode(&result) err := channel.Decoder.Decode(&result)
if err != nil { if err != nil {
// TODO more clean solution, need to explicitly close when agent exits.
log.Printf("Exiting agent listener %v", err) log.Printf("Exiting agent listener %v", err)
return return
} }
@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) {
err := channel.SideChannel.Decoder.Decode(&result) err := channel.SideChannel.Decoder.Decode(&result)
if err != nil { if err != nil {
// TODO more clean solution, need to explicitly close when agent exits.
log.Printf("Exiting agent listener %v", err) log.Printf("Exiting agent listener %v", err)
return return
} }
@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e
return ServerInfo{}, nil return ServerInfo{}, nil
} }
// TODO remove logging // TODO remove logging
log.Println("Server info received: ", serverInfo) log.Println("Agent configuration received from server")
return serverInfo, err return serverInfo, err
} }
@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo,
// is terminated. // is terminated.
func CheckProtocolVersion(role Role, channel GOBChannel) error { func CheckProtocolVersion(role Role, channel GOBChannel) error {
log.Println("ROLE ", role)
switch role { switch role {
case Agent: case Agent:
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
// decorates the yamux Session (which is a listener) and uses this connection to exchange some // decorates the yamux Session (which is a listener) and uses this connection to exchange some
// metadata before the connection is handed back to SSH. // metadata before the connection is handed back to SSH.
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
channel := NewGOBChannel(conn)
return SendWithTimeout(channel, info)
}
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
channel := NewGOBChannel(conn)
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
if err != nil {
return ClientInfo{}, err
}
return clientInfo, nil
}

View File

@ -24,7 +24,7 @@ type AgentInfo struct {
} }
type ClientInfo struct { type ClientInfo struct {
ClientId string ClientId int
} }
type SessionInfo struct { type SessionInfo struct {

View File

@ -0,0 +1,22 @@
package concurrency
import "sync"
type AtomicCounter struct {
mutex sync.Mutex
lastValue int
}
func NewAtomicCounter() *AtomicCounter {
return &AtomicCounter{
mutex: sync.Mutex{},
lastValue: 0,
}
}
func (counter *AtomicCounter) IncrementAndGet() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
counter.lastValue++
return counter.lastValue
}

View File

@ -2,11 +2,13 @@ package converge
import ( import (
"converge/pkg/comms" "converge/pkg/comms"
"converge/pkg/concurrency"
"converge/pkg/iowrappers" "converge/pkg/iowrappers"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"strconv"
"sync" "sync"
"time" "time"
) )
@ -21,8 +23,11 @@ type Agent struct {
expiryTime time.Time expiryTime time.Time
} }
var clientIdGenerator = concurrency.NewAtomicCounter()
type Client struct { type Client struct {
publicId string publicId string
clientId int
agent net.Conn agent net.Conn
client iowrappers.ReadWriteAddrCloser client iowrappers.ReadWriteAddrCloser
startTime time.Time startTime time.Time
@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
agentConn net.Conn) *Client { agentConn net.Conn) *Client {
return &Client{ return &Client{
publicId: publicId, publicId: publicId,
clientId: clientIdGenerator.IncrementAndGet(),
agent: agentConn, agent: agentConn,
client: clientConn, client: clientConn,
startTime: time.Now(), startTime: time.Now(),
@ -78,10 +84,12 @@ func (admin *Admin) logStatus() {
agent.agentInfo.OS) agent.agentInfo.OS)
} }
log.Println("") log.Println("")
fmt = "%-20s %-20s %-20s %-20s\n" fmt = "%-10s %-20s %-20s %-20s %-20s\n"
log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE") log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
for _, client := range admin.clients { for _, client := range admin.clients {
log.Printf(fmt, client.publicId, log.Printf(fmt,
strconv.Itoa(client.clientId),
client.publicId,
client.startTime.Format(time.DateTime), client.startTime.Format(time.DateTime),
client.client.RemoteAddr(), client.client.RemoteAddr(),
client.sessionType) client.sessionType)
@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
return nil, err return nil, err
} }
log.Println("Successful websocket connection to agent") log.Println("Successful websocket connection to agent")
log.Println("Sending connection information to agent")
client := NewClient(publicId, clientConn, agentConn) client := NewClient(publicId, clientConn, agentConn)
// Before using this connection for SSH we use it to send client metadata to the
// agent
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
ClientId: client.clientId,
})
if err != nil {
return nil, err
}
admin.clients = append(admin.clients, client) admin.clients = append(admin.clients, client)
admin.logStatus() admin.logStatus()
return client, nil return client, nil