Unique ids for clients generated by converge server and made available to the ssh session through a net.Conn extension that passes the ID to the SSH session through the LocalAddr().
This commit is contained in:
parent
9d0675b2f2
commit
788050df32
@ -32,7 +32,7 @@ import (
|
||||
var hostPrivateKey []byte
|
||||
|
||||
func SftpHandler(sess ssh.Session) {
|
||||
uid := int(time.Now().UnixMicro())
|
||||
uid := sess.LocalAddr().String()
|
||||
agent.Login(uid, sess)
|
||||
defer agent.LogOut(uid)
|
||||
|
||||
@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string,
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
uid := int(time.Now().UnixMilli())
|
||||
uid := s.LocalAddr().String()
|
||||
agent.Login(uid, s)
|
||||
iowrappers.SynchronizeStreams(process.Pipe(), s)
|
||||
agent.LogOut(uid)
|
||||
process.Wait()
|
||||
process.Wait()
|
||||
})
|
||||
|
||||
log.Println("starting ssh server, waiting for debug sessions")
|
||||
|
@ -12,7 +12,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
@ -49,7 +48,7 @@ type AgentState struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
// map of unique session id to a session
|
||||
sessions map[int]*AgentSession
|
||||
sessions map[string]*AgentSession
|
||||
|
||||
lastUserLoginTime time.Time
|
||||
agentUsed bool
|
||||
@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
||||
lastExpiryTimmeReported: time.Time{},
|
||||
tickerInterval: tickerInterval,
|
||||
ticker: time.NewTicker(tickerInterval),
|
||||
sessions: make(map[int]*AgentSession),
|
||||
sessions: make(map[string]*AgentSession),
|
||||
|
||||
lastUserLoginTime: time.Time{},
|
||||
agentUsed: false,
|
||||
@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
||||
|
||||
}
|
||||
|
||||
func Login(sessionId int, sshSession ssh.Session) {
|
||||
func Login(sessionId string, sshSession ssh.Session) {
|
||||
events <- async.Async(login, sessionId, sshSession)
|
||||
}
|
||||
|
||||
func LogOut(sessionId int) {
|
||||
func LogOut(sessionId string) {
|
||||
events <- async.Async(logOut, sessionId)
|
||||
}
|
||||
|
||||
@ -176,7 +175,7 @@ func holdFileMessage() string {
|
||||
return message
|
||||
}
|
||||
|
||||
func login(sessionId int, sshSession ssh.Session) {
|
||||
func login(sessionId string, sshSession ssh.Session) {
|
||||
log.Println("New login")
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
@ -238,7 +237,7 @@ func formatHelpMessage() string {
|
||||
return helpFormatted
|
||||
}
|
||||
|
||||
func logOut(sessionId int) {
|
||||
func logOut(sessionId string) {
|
||||
log.Println("User logged out")
|
||||
delete(state.sessions, sessionId)
|
||||
logStatus()
|
||||
@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) {
|
||||
func logStatus() {
|
||||
fmt := "%-20s %-20s %-20s"
|
||||
log.Println()
|
||||
log.Printf(fmt, "UID", "START_TIME", "TYPE")
|
||||
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
|
||||
for uid, session := range state.sessions {
|
||||
sessionType := session.sshSession.Subsystem()
|
||||
if sessionType == "" {
|
||||
sessionType = "ssh"
|
||||
}
|
||||
log.Printf(fmt, strconv.Itoa(uid),
|
||||
log.Printf(fmt, uid,
|
||||
session.startTime.Format(time.DateTime),
|
||||
sessionType)
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"converge/pkg/websocketutil"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type AgentListener struct {
|
||||
@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener {
|
||||
return AgentListener{decorated: listener}
|
||||
}
|
||||
|
||||
type LocalAddrHackConn struct {
|
||||
net.Conn
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
func (conn LocalAddrHackConn) LocalAddr() net.Addr {
|
||||
return conn.localAddr
|
||||
}
|
||||
|
||||
func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
|
||||
addr := LocalAddrHackConn{
|
||||
localAddr: websocketutil.WebSocketAddr(clientId),
|
||||
}
|
||||
addr.Conn = conn
|
||||
return addr
|
||||
}
|
||||
|
||||
func (listener AgentListener) Accept() (net.Conn, error) {
|
||||
conn, err := listener.decorated.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//_, err = CheckProtocolVersion(Agent, conn)
|
||||
//if err != nil {
|
||||
// conn.Close()
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
return conn, nil
|
||||
clientInfo, err := ReceiveClientInfo(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil
|
||||
}
|
||||
|
||||
func (listener AgentListener) Close() error {
|
||||
|
@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel,
|
||||
err := channel.Decoder.Decode(&result)
|
||||
|
||||
if err != nil {
|
||||
// TODO more clean solution, need to explicitly close when agent exits.
|
||||
log.Printf("Exiting agent listener %v", err)
|
||||
return
|
||||
}
|
||||
@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) {
|
||||
err := channel.SideChannel.Decoder.Decode(&result)
|
||||
|
||||
if err != nil {
|
||||
// TODO more clean solution, need to explicitly close when agent exits.
|
||||
log.Printf("Exiting agent listener %v", err)
|
||||
return
|
||||
}
|
||||
@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e
|
||||
return ServerInfo{}, nil
|
||||
}
|
||||
// TODO remove logging
|
||||
log.Println("Server info received: ", serverInfo)
|
||||
log.Println("Agent configuration received from server")
|
||||
|
||||
return serverInfo, err
|
||||
}
|
||||
@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo,
|
||||
// is terminated.
|
||||
|
||||
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
||||
log.Println("ROLE ", role)
|
||||
switch role {
|
||||
case Agent:
|
||||
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
||||
@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
|
||||
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
|
||||
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
||||
// metadata before the connection is handed back to SSH.
|
||||
|
||||
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
|
||||
channel := NewGOBChannel(conn)
|
||||
return SendWithTimeout(channel, info)
|
||||
}
|
||||
|
||||
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
|
||||
channel := NewGOBChannel(conn)
|
||||
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
|
||||
if err != nil {
|
||||
return ClientInfo{}, err
|
||||
}
|
||||
return clientInfo, nil
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ type AgentInfo struct {
|
||||
}
|
||||
|
||||
type ClientInfo struct {
|
||||
ClientId string
|
||||
ClientId int
|
||||
}
|
||||
|
||||
type SessionInfo struct {
|
||||
|
22
pkg/concurrency/atomiccounter.go
Normal file
22
pkg/concurrency/atomiccounter.go
Normal file
@ -0,0 +1,22 @@
|
||||
package concurrency
|
||||
|
||||
import "sync"
|
||||
|
||||
type AtomicCounter struct {
|
||||
mutex sync.Mutex
|
||||
lastValue int
|
||||
}
|
||||
|
||||
func NewAtomicCounter() *AtomicCounter {
|
||||
return &AtomicCounter{
|
||||
mutex: sync.Mutex{},
|
||||
lastValue: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (counter *AtomicCounter) IncrementAndGet() int {
|
||||
counter.mutex.Lock()
|
||||
defer counter.mutex.Unlock()
|
||||
counter.lastValue++
|
||||
return counter.lastValue
|
||||
}
|
@ -2,11 +2,13 @@ package converge
|
||||
|
||||
import (
|
||||
"converge/pkg/comms"
|
||||
"converge/pkg/concurrency"
|
||||
"converge/pkg/iowrappers"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -21,8 +23,11 @@ type Agent struct {
|
||||
expiryTime time.Time
|
||||
}
|
||||
|
||||
var clientIdGenerator = concurrency.NewAtomicCounter()
|
||||
|
||||
type Client struct {
|
||||
publicId string
|
||||
clientId int
|
||||
agent net.Conn
|
||||
client iowrappers.ReadWriteAddrCloser
|
||||
startTime time.Time
|
||||
@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
|
||||
agentConn net.Conn) *Client {
|
||||
return &Client{
|
||||
publicId: publicId,
|
||||
clientId: clientIdGenerator.IncrementAndGet(),
|
||||
agent: agentConn,
|
||||
client: clientConn,
|
||||
startTime: time.Now(),
|
||||
@ -78,10 +84,12 @@ func (admin *Admin) logStatus() {
|
||||
agent.agentInfo.OS)
|
||||
}
|
||||
log.Println("")
|
||||
fmt = "%-20s %-20s %-20s %-20s\n"
|
||||
log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
|
||||
fmt = "%-10s %-20s %-20s %-20s %-20s\n"
|
||||
log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
|
||||
for _, client := range admin.clients {
|
||||
log.Printf(fmt, client.publicId,
|
||||
log.Printf(fmt,
|
||||
strconv.Itoa(client.clientId),
|
||||
client.publicId,
|
||||
client.startTime.Format(time.DateTime),
|
||||
client.client.RemoteAddr(),
|
||||
client.sessionType)
|
||||
@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
|
||||
return nil, err
|
||||
}
|
||||
log.Println("Successful websocket connection to agent")
|
||||
log.Println("Sending connection information to agent")
|
||||
|
||||
client := NewClient(publicId, clientConn, agentConn)
|
||||
|
||||
// Before using this connection for SSH we use it to send client metadata to the
|
||||
// agent
|
||||
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
|
||||
ClientId: client.clientId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
admin.clients = append(admin.clients, client)
|
||||
admin.logStatus()
|
||||
return client, nil
|
||||
|
Loading…
Reference in New Issue
Block a user