Unique ids for clients generated by converge server and made available to the ssh session through a net.Conn extension that passes the ID to the SSH session through the LocalAddr().
This commit is contained in:
parent
9d0675b2f2
commit
788050df32
@ -32,7 +32,7 @@ import (
|
|||||||
var hostPrivateKey []byte
|
var hostPrivateKey []byte
|
||||||
|
|
||||||
func SftpHandler(sess ssh.Session) {
|
func SftpHandler(sess ssh.Session) {
|
||||||
uid := int(time.Now().UnixMicro())
|
uid := sess.LocalAddr().String()
|
||||||
agent.Login(uid, sess)
|
agent.Login(uid, sess)
|
||||||
defer agent.LogOut(uid)
|
defer agent.LogOut(uid)
|
||||||
|
|
||||||
@ -66,12 +66,11 @@ func sshServer(hostKeyFile string, shellCommand string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
uid := int(time.Now().UnixMilli())
|
uid := s.LocalAddr().String()
|
||||||
agent.Login(uid, s)
|
agent.Login(uid, s)
|
||||||
iowrappers.SynchronizeStreams(process.Pipe(), s)
|
iowrappers.SynchronizeStreams(process.Pipe(), s)
|
||||||
agent.LogOut(uid)
|
agent.LogOut(uid)
|
||||||
process.Wait()
|
process.Wait()
|
||||||
process.Wait()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Println("starting ssh server, waiting for debug sessions")
|
log.Println("starting ssh server, waiting for debug sessions")
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
@ -49,7 +48,7 @@ type AgentState struct {
|
|||||||
ticker *time.Ticker
|
ticker *time.Ticker
|
||||||
|
|
||||||
// map of unique session id to a session
|
// map of unique session id to a session
|
||||||
sessions map[int]*AgentSession
|
sessions map[string]*AgentSession
|
||||||
|
|
||||||
lastUserLoginTime time.Time
|
lastUserLoginTime time.Time
|
||||||
agentUsed bool
|
agentUsed bool
|
||||||
@ -92,7 +91,7 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
|||||||
lastExpiryTimmeReported: time.Time{},
|
lastExpiryTimmeReported: time.Time{},
|
||||||
tickerInterval: tickerInterval,
|
tickerInterval: tickerInterval,
|
||||||
ticker: time.NewTicker(tickerInterval),
|
ticker: time.NewTicker(tickerInterval),
|
||||||
sessions: make(map[int]*AgentSession),
|
sessions: make(map[string]*AgentSession),
|
||||||
|
|
||||||
lastUserLoginTime: time.Time{},
|
lastUserLoginTime: time.Time{},
|
||||||
agentUsed: false,
|
agentUsed: false,
|
||||||
@ -122,11 +121,11 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Login(sessionId int, sshSession ssh.Session) {
|
func Login(sessionId string, sshSession ssh.Session) {
|
||||||
events <- async.Async(login, sessionId, sshSession)
|
events <- async.Async(login, sessionId, sshSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogOut(sessionId int) {
|
func LogOut(sessionId string) {
|
||||||
events <- async.Async(logOut, sessionId)
|
events <- async.Async(logOut, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +175,7 @@ func holdFileMessage() string {
|
|||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
func login(sessionId int, sshSession ssh.Session) {
|
func login(sessionId string, sshSession ssh.Session) {
|
||||||
log.Println("New login")
|
log.Println("New login")
|
||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
|
|
||||||
@ -238,7 +237,7 @@ func formatHelpMessage() string {
|
|||||||
return helpFormatted
|
return helpFormatted
|
||||||
}
|
}
|
||||||
|
|
||||||
func logOut(sessionId int) {
|
func logOut(sessionId string) {
|
||||||
log.Println("User logged out")
|
log.Println("User logged out")
|
||||||
delete(state.sessions, sessionId)
|
delete(state.sessions, sessionId)
|
||||||
logStatus()
|
logStatus()
|
||||||
@ -255,13 +254,13 @@ func printMessage(sshSession ssh.Session, message string) {
|
|||||||
func logStatus() {
|
func logStatus() {
|
||||||
fmt := "%-20s %-20s %-20s"
|
fmt := "%-20s %-20s %-20s"
|
||||||
log.Println()
|
log.Println()
|
||||||
log.Printf(fmt, "UID", "START_TIME", "TYPE")
|
log.Printf(fmt, "CLIENT", "START_TIME", "TYPE")
|
||||||
for uid, session := range state.sessions {
|
for uid, session := range state.sessions {
|
||||||
sessionType := session.sshSession.Subsystem()
|
sessionType := session.sshSession.Subsystem()
|
||||||
if sessionType == "" {
|
if sessionType == "" {
|
||||||
sessionType = "ssh"
|
sessionType = "ssh"
|
||||||
}
|
}
|
||||||
log.Printf(fmt, strconv.Itoa(uid),
|
log.Printf(fmt, uid,
|
||||||
session.startTime.Format(time.DateTime),
|
session.startTime.Format(time.DateTime),
|
||||||
sessionType)
|
sessionType)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package comms
|
package comms
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"converge/pkg/websocketutil"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AgentListener struct {
|
type AgentListener struct {
|
||||||
@ -12,19 +14,35 @@ func NewAgentListener(listener net.Listener) AgentListener {
|
|||||||
return AgentListener{decorated: listener}
|
return AgentListener{decorated: listener}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LocalAddrHackConn struct {
|
||||||
|
net.Conn
|
||||||
|
localAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn LocalAddrHackConn) LocalAddr() net.Addr {
|
||||||
|
return conn.localAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLocalAddrHackConn(conn net.Conn, clientId string) LocalAddrHackConn {
|
||||||
|
addr := LocalAddrHackConn{
|
||||||
|
localAddr: websocketutil.WebSocketAddr(clientId),
|
||||||
|
}
|
||||||
|
addr.Conn = conn
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
func (listener AgentListener) Accept() (net.Conn, error) {
|
func (listener AgentListener) Accept() (net.Conn, error) {
|
||||||
conn, err := listener.decorated.Accept()
|
conn, err := listener.decorated.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//_, err = CheckProtocolVersion(Agent, conn)
|
clientInfo, err := ReceiveClientInfo(conn)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// conn.Close()
|
conn.Close()
|
||||||
// return nil, err
|
return nil, err
|
||||||
//}
|
}
|
||||||
|
return NewLocalAddrHackConn(conn, strconv.Itoa(clientInfo.ClientId)), nil
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (listener AgentListener) Close() error {
|
func (listener AgentListener) Close() error {
|
||||||
|
@ -97,7 +97,6 @@ func ListenForAgentEvents(channel GOBChannel,
|
|||||||
err := channel.Decoder.Decode(&result)
|
err := channel.Decoder.Decode(&result)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO more clean solution, need to explicitly close when agent exits.
|
|
||||||
log.Printf("Exiting agent listener %v", err)
|
log.Printf("Exiting agent listener %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -129,7 +128,6 @@ func ListenForServerEvents(channel CommChannel) {
|
|||||||
err := channel.SideChannel.Decoder.Decode(&result)
|
err := channel.SideChannel.Decoder.Decode(&result)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO more clean solution, need to explicitly close when agent exits.
|
|
||||||
log.Printf("Exiting agent listener %v", err)
|
log.Printf("Exiting agent listener %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -156,7 +154,7 @@ func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, e
|
|||||||
return ServerInfo{}, nil
|
return ServerInfo{}, nil
|
||||||
}
|
}
|
||||||
// TODO remove logging
|
// TODO remove logging
|
||||||
log.Println("Server info received: ", serverInfo)
|
log.Println("Agent configuration received from server")
|
||||||
|
|
||||||
return serverInfo, err
|
return serverInfo, err
|
||||||
}
|
}
|
||||||
@ -185,7 +183,6 @@ func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo,
|
|||||||
// is terminated.
|
// is terminated.
|
||||||
|
|
||||||
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
||||||
log.Println("ROLE ", role)
|
|
||||||
switch role {
|
switch role {
|
||||||
case Agent:
|
case Agent:
|
||||||
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
||||||
@ -259,3 +256,17 @@ func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
|
|||||||
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
|
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
|
||||||
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
||||||
// metadata before the connection is handed back to SSH.
|
// metadata before the connection is handed back to SSH.
|
||||||
|
|
||||||
|
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
|
||||||
|
channel := NewGOBChannel(conn)
|
||||||
|
return SendWithTimeout(channel, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
|
||||||
|
channel := NewGOBChannel(conn)
|
||||||
|
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
|
||||||
|
if err != nil {
|
||||||
|
return ClientInfo{}, err
|
||||||
|
}
|
||||||
|
return clientInfo, nil
|
||||||
|
}
|
||||||
|
@ -24,7 +24,7 @@ type AgentInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ClientInfo struct {
|
type ClientInfo struct {
|
||||||
ClientId string
|
ClientId int
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
|
22
pkg/concurrency/atomiccounter.go
Normal file
22
pkg/concurrency/atomiccounter.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package concurrency
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
type AtomicCounter struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
lastValue int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAtomicCounter() *AtomicCounter {
|
||||||
|
return &AtomicCounter{
|
||||||
|
mutex: sync.Mutex{},
|
||||||
|
lastValue: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *AtomicCounter) IncrementAndGet() int {
|
||||||
|
counter.mutex.Lock()
|
||||||
|
defer counter.mutex.Unlock()
|
||||||
|
counter.lastValue++
|
||||||
|
return counter.lastValue
|
||||||
|
}
|
@ -2,11 +2,13 @@ package converge
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"converge/pkg/comms"
|
"converge/pkg/comms"
|
||||||
|
"converge/pkg/concurrency"
|
||||||
"converge/pkg/iowrappers"
|
"converge/pkg/iowrappers"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -21,8 +23,11 @@ type Agent struct {
|
|||||||
expiryTime time.Time
|
expiryTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var clientIdGenerator = concurrency.NewAtomicCounter()
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
publicId string
|
publicId string
|
||||||
|
clientId int
|
||||||
agent net.Conn
|
agent net.Conn
|
||||||
client iowrappers.ReadWriteAddrCloser
|
client iowrappers.ReadWriteAddrCloser
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
@ -42,6 +47,7 @@ func NewClient(publicId string, clientConn iowrappers.ReadWriteAddrCloser,
|
|||||||
agentConn net.Conn) *Client {
|
agentConn net.Conn) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
publicId: publicId,
|
publicId: publicId,
|
||||||
|
clientId: clientIdGenerator.IncrementAndGet(),
|
||||||
agent: agentConn,
|
agent: agentConn,
|
||||||
client: clientConn,
|
client: clientConn,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
@ -78,10 +84,12 @@ func (admin *Admin) logStatus() {
|
|||||||
agent.agentInfo.OS)
|
agent.agentInfo.OS)
|
||||||
}
|
}
|
||||||
log.Println("")
|
log.Println("")
|
||||||
fmt = "%-20s %-20s %-20s %-20s\n"
|
fmt = "%-10s %-20s %-20s %-20s %-20s\n"
|
||||||
log.Printf(fmt, "CLIENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
|
log.Printf(fmt, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE")
|
||||||
for _, client := range admin.clients {
|
for _, client := range admin.clients {
|
||||||
log.Printf(fmt, client.publicId,
|
log.Printf(fmt,
|
||||||
|
strconv.Itoa(client.clientId),
|
||||||
|
client.publicId,
|
||||||
client.startTime.Format(time.DateTime),
|
client.startTime.Format(time.DateTime),
|
||||||
client.client.RemoteAddr(),
|
client.client.RemoteAddr(),
|
||||||
client.sessionType)
|
client.sessionType)
|
||||||
@ -124,8 +132,19 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers.ReadWriteAd
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Println("Successful websocket connection to agent")
|
log.Println("Successful websocket connection to agent")
|
||||||
|
log.Println("Sending connection information to agent")
|
||||||
|
|
||||||
client := NewClient(publicId, clientConn, agentConn)
|
client := NewClient(publicId, clientConn, agentConn)
|
||||||
|
|
||||||
|
// Before using this connection for SSH we use it to send client metadata to the
|
||||||
|
// agent
|
||||||
|
err = comms.SendClientInfo(agentConn, comms.ClientInfo{
|
||||||
|
ClientId: client.clientId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
admin.clients = append(admin.clients, client)
|
admin.clients = append(admin.clients, client)
|
||||||
admin.logStatus()
|
admin.logStatus()
|
||||||
return client, nil
|
return client, nil
|
||||||
|
Loading…
Reference in New Issue
Block a user