Unique ids for clients generated by converge server and made available to the ssh session through a net.Conn extension that passes the ID to the SSH session through the LocalAddr().

This commit is contained in:
Erik Brakkee 2024-07-27 22:37:40 +02:00
parent 9d0675b2f2
commit 788050df32
7 changed files with 95 additions and 27 deletions

View File

@ -32,7 +32,7 @@ import (
var hostPrivateKey []byte
func SftpHandler(sess ssh.Session) {
uid := int(time.Now().UnixMicro())
uid := sess.LocalAddr().String()
agent.Login(uid, sess)
defer agent.LogOut(uid)
@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string,
if err != nil {
panic(err)
}
uid := int(time.Now().UnixMilli())
uid := s.LocalAddr().String()
agent.Login(uid, s)
iowrappers.SynchronizeStreams(process.Pipe(), s)
agent.LogOut(uid)
process.Wait()
process.Wait()
})
log.Println("starting ssh server, waiting for debug sessions")

View File

@ -12,7 +12,6 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"text/template"
"time"
@ -49,7 +48,7 @@ type AgentState struct {
ticker *time.Ticker
// map of unique session id to a session
sessions map[int]*AgentSession
sessions map[string]*AgentSession
lastUserLoginTime time.Time
agentUsed bool
@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
lastExpiryTimmeReported: time.Time{},
tickerInterval: tickerInterval,
ticker: time.NewTicker(tickerInterval),
sessions: make(map[int]*AgentSession),
sessions: make(map[string]*AgentSession),
lastUserLoginTime: time.Time{},
agentUsed: false,
@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel,
}
func Login(sessionId int, sshSession ssh.Session) {
func Login(sessionId string, sshSession ssh.Session) {
events <- async.Async(login, sessionId, sshSession)
}
func LogOut(sessionId int) {
func LogOut(sessionId string) {
events <- async.Async(logOut, sessionId)
}
@ -176,7 +175,7 @@ func holdFileMessage() string {
return message
}
func login(sessionId int, sshSession ssh.Session) {
func login(sessionId string, sshSession ssh.Session) {
log.Println("New login")
hostname, _ := os.Hostname()
@ -238,7 +237,7 @@ func formatHelpMessage() string {
return helpFormatted
}
func logOut(sessionId int) {
func logOut(sessionId string) {
log.Println("User logged out")
delete(state.sessions, sessionId)
logStatus()
@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) {
func logStatus() {
fmt := "%-20s %-20s %-20s"
log.Println()
log.Printf(fmt, "UID", "START_TIME", "TYPE")
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
for uid, session := range state.sessions {
sessionType := session.sshSession.Subsystem()
if sessionType == "" {
sessionType = "ssh"
}
log.Printf(fmt, strconv.Itoa(uid),
log.Printf(fmt, uid,
session.startTime.Format(time.DateTime),
sessionType)
}

View File

@ -1,7 +1,9 @@
package comms
import (
"converge/pkg/websocketutil"
"net"
"strconv"
)
type AgentListener struct {
@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener {
return AgentListener{decorated: listener}
}
type LocalAddrHackConn struct {
net.Conn
localAddr net.Addr
}
func (conn LocalAddrHackConn) LocalAddr() net.Addr {
return conn.localAddr
}
func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
addr := LocalAddrHackConn{
localAddr: websocketutil.WebSocketAddr(clientId),
}
addr.Conn = conn
return addr
}
func (listener AgentListener) Accept() (net.Conn, error) {
conn, err := listener.decorated.Accept()
if err != nil {
return nil, err
}
//_, err = CheckProtocolVersion(Agent, conn)
//if err != nil {
// conn.Close()
// return nil, err
//}
return conn, nil
clientInfo, err := ReceiveClientInfo(conn)
if err != nil {
conn.Close()
return nil, err
}
return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil
}
func (listener AgentListener) Close() error {

View File

@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel,
err := channel.Decoder.Decode(&result)
if err != nil {
// TODO more clean solution, need to explicitly close when agent exits.
log.Printf("Exiting agent listener %v", err)
return
}
@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) {
err := channel.SideChannel.Decoder.Decode(&result)
if err != nil {
// TODO more clean solution, need to explicitly close when agent exits.
log.Printf("Exiting agent listener %v", err)
return
}
@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e
return ServerInfo{}, nil
}
// TODO remove logging
log.Println("Server info received: ", serverInfo)
log.Println("Agent configuration received from server")
return serverInfo, err
}
@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo,
// is terminated.
func CheckProtocolVersion(role Role, channel GOBChannel) error {
log.Println("ROLE ", role)
switch role {
case Agent:
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
// metadata before the connection is handed back to SSH.
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
channel := NewGOBChannel(conn)
return SendWithTimeout(channel, info)
}
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
channel := NewGOBChannel(conn)
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
if err != nil {
return ClientInfo{}, err
}
return clientInfo, nil
}

View File

@ -24,7 +24,7 @@ type AgentInfo struct {
}
type ClientInfo struct {
ClientId string
ClientId int
}
type SessionInfo struct {

View File

@ -0,0 +1,22 @@
package concurrency
import "sync"
type AtomicCounter struct {
mutex sync.Mutex
lastValue int
}
func NewAtomicCounter() *AtomicCounter {
return &AtomicCounter{
mutex: sync.Mutex{},
lastValue: 0,
}
}
func (counter *AtomicCounter) IncrementAndGet() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
counter.lastValue++
return counter.lastValue
}

View File

@ -2,11 +2,13 @@ package converge
import (
"converge/pkg/comms"
"converge/pkg/concurrency"
"converge/pkg/iowrappers"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"time"
)
@ -21,8 +23,11 @@ type Agent struct {
expiryTime time.Time
}
var clientIdGenerator = concurrency.NewAtomicCounter()
type Client struct {
publicId string
clientId int
agent net.Conn
client iowrappers.ReadWriteAddrCloser
startTime time.Time
@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
agentConn net.Conn) *Client {
return &Client{
publicId: publicId,
clientId: clientIdGenerator.IncrementAndGet(),
agent: agentConn,
client: clientConn,
startTime: time.Now(),
@ -78,10 +84,12 @@ func (admin *Admin) logStatus() {
agent.agentInfo.OS)
}
log.Println("")
fmt = "%-20s %-20s %-20s %-20s\n"
log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
fmt = "%-10s %-20s %-20s %-20s %-20s\n"
log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
for _, client := range admin.clients {
log.Printf(fmt, client.publicId,
log.Printf(fmt,
strconv.Itoa(client.clientId),
client.publicId,
client.startTime.Format(time.DateTime),
client.client.RemoteAddr(),
client.sessionType)
@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
return nil, err
}
log.Println("Successful websocket connection to agent")
log.Println("Sending connection information to agent")
client := NewClient(publicId, clientConn, agentConn)
// Before using this connection for SSH we use it to send client metadata to the
// agent
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
ClientId: client.clientId,
})
if err != nil {
return nil, err
}
admin.clients = append(admin.clients, client)
admin.logStatus()
return client, nil