initialization of username, password on client (from server) and initialization of agentinfo on server is now done as soon as the agent registered and not through a side channel.

Making use of some simple utilities for GOB to make it easy to send objects over the line.
This commit is contained in:
Erik Brakkee 2024-07-27 20:46:53 +02:00
parent ada34495ef
commit 9d0675b2f2
6 changed files with 188 additions and 63 deletions

View File

@ -14,7 +14,6 @@ import (
"github.com/pkg/sftp" "github.com/pkg/sftp"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -33,7 +32,7 @@ import (
var hostPrivateKey []byte var hostPrivateKey []byte
func SftpHandler(sess ssh.Session) { func SftpHandler(sess ssh.Session) {
uid := int(time.Now().UnixMilli()) uid := int(time.Now().UnixMicro())
agent.Login(uid, sess) agent.Login(uid, sess)
defer agent.LogOut(uid) defer agent.LogOut(uid)
@ -249,8 +248,9 @@ func main() {
wsConn := websocketutil.NewWebSocketConn(conn) wsConn := websocketutil.NewWebSocketConn(conn)
defer wsConn.Close() defer wsConn.Close()
err = comms.CheckProtocolVersion(comms.Agent, wsConn) serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo())
if err != nil { if err != nil {
log.Printf("ERROR: %+v", err)
os.Exit(1) os.Exit(1)
} }
@ -261,7 +261,10 @@ func main() {
// Authentiocation // Authentiocation
sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile) passwordHandler, authorizedKeys := setupAuthentication(
commChannel,
serverInfo.UserPassword,
authorizedKeysFile)
// Choose shell // Choose shell
@ -279,9 +282,9 @@ func main() {
log.Println() log.Println()
clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/") clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/")
sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
clientUrl, sshUserCredentials.Username) clientUrl, serverInfo.UserPassword.Username)
sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
clientUrl, sshUserCredentials.Username) clientUrl, serverInfo.UserPassword.Username)
log.Println(" # For SSH") log.Println(" # For SSH")
log.Println(" " + sshCommand) log.Println(" " + sshCommand)
log.Println() log.Println()
@ -302,26 +305,20 @@ func main() {
service.Run(listener) service.Run(listener)
} }
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { func setupAuthentication(commChannel comms.CommChannel,
// Random user name and password so that effectively no one can login userPassword comms.UserPassword,
// until the user and password have been received from the server. authorizedKeysFile string) (func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
sshUserCredentials := comms.UserPassword{
Username: strconv.Itoa(rand.Int()),
Password: strconv.Itoa(rand.Int()),
}
passwordHandler := func(ctx ssh.Context, password string) bool { passwordHandler := func(ctx ssh.Context, password string) bool {
// Replace with your own logic to validate username and password // Replace with your own logic to validate username and password
return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password return ctx.User() == userPassword.Username && password == userPassword.Password
} }
go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { go comms.ListenForServerEvents(commChannel)
log.Println("Username and password configuration received from server")
sshUserCredentials = user
})
authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile)
if len(authorizedKeys.keys) > 0 { if len(authorizedKeys.keys) > 0 {
log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys))
} }
return sshUserCredentials, passwordHandler, authorizedKeys return passwordHandler, authorizedKeys
} }
func chooseShell() string { func chooseShell() string {

View File

@ -101,8 +101,10 @@ func ConfigureAgent(commChannel comms.CommChannel,
log.Printf("Agent expires at %s", log.Printf("Agent expires at %s",
state.expiryTime(holdFilename).Format(time.DateTime)) state.expiryTime(holdFilename).Format(time.DateTime))
state.commChannel.SideChannel.Send(comms.NewAgentInfo()) comms.Send(state.commChannel.SideChannel,
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) comms.ConvergeMessage{
Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
})
go func() { go func() {
for { for {
@ -182,7 +184,10 @@ func login(sessionId int, sshSession ssh.Session) {
if sessionType == "" { if sessionType == "" {
sessionType = "ssh" sessionType = "ssh"
} }
state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType)) comms.Send(state.commChannel.SideChannel,
comms.ConvergeMessage{
Value: comms.NewSessionInfo(sessionType),
})
holdFileStats, ok := fileExistsWithStats(holdFilename) holdFileStats, ok := fileExistsWithStats(holdFilename)
if ok { if ok {
@ -297,7 +302,10 @@ func holdFileChange() {
message += holdFileMessage() message += holdFileMessage()
messageUsers(message) messageUsers(message)
state.lastExpiryTimmeReported = newExpiryTIme state.lastExpiryTimmeReported = newExpiryTIme
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) comms.Send(state.commChannel.SideChannel,
comms.ConvergeMessage{
Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
})
} }
} }

View File

@ -8,6 +8,8 @@ import (
"time" "time"
) )
const MESSAGE_TIMEOUT = 10 * time.Second
type CommChannel struct { type CommChannel struct {
// a separet connection outside of the ssh session // a separet connection outside of the ssh session
SideChannel GOBChannel SideChannel GOBChannel
@ -70,7 +72,10 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
go func() { go func() {
for { for {
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
err := commChannel.SideChannel.Send(HeartBeat{}) err := Send(commChannel.SideChannel,
ConvergeMessage{
Value: HeartBeat{},
})
if err != nil { if err != nil {
log.Println("Sending heartbeat to server failed") log.Println("Sending heartbeat to server failed")
} }
@ -113,13 +118,12 @@ func ListenForAgentEvents(channel GOBChannel,
// intended to keep the connection up // intended to keep the connection up
default: default:
fmt.Printf(" Unknown type: %T\n", v) fmt.Printf(" Unknown type: %v %T\n", v, v)
} }
} }
} }
func ListenForServerEvents(channel CommChannel, func ListenForServerEvents(channel CommChannel) {
setUsernamePassword func(user UserPassword)) {
for { for {
var result ConvergeMessage var result ConvergeMessage
err := channel.SideChannel.Decoder.Decode(&result) err := channel.SideChannel.Decoder.Decode(&result)
@ -129,10 +133,9 @@ func ListenForServerEvents(channel CommChannel,
log.Printf("Exiting agent listener %v", err) log.Printf("Exiting agent listener %v", err)
return return
} }
switch v := result.Value.(type) {
case UserPassword: // no supported server events at this time.
setUsernamePassword(v) switch v := result.Value.(type) {
default: default:
fmt.Printf(" Unknown type: %T\n", v) fmt.Printf(" Unknown type: %T\n", v)
@ -140,25 +143,102 @@ func ListenForServerEvents(channel CommChannel,
} }
} }
func CheckProtocolVersion(role Role, conn io.ReadWriter) error { func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) {
channel := NewGOBChannel(conn)
err := CheckProtocolVersion(Agent, channel)
err = SendWithTimeout(channel, agentInto)
if err != nil {
return ServerInfo{}, nil
}
serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
if err != nil {
return ServerInfo{}, nil
}
// TODO remove logging
log.Println("Server info received: ", serverInfo)
return serverInfo, err
}
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) {
channel := NewGOBChannel(conn)
err := CheckProtocolVersion(ConvergeServer, channel)
agentInfo, err := ReceiveWithTimeout[AgentInfo](channel)
if err != nil {
return AgentInfo{}, err
}
log.Println("Agent info received: ", agentInfo)
err = SendWithTimeout(channel, serverInfo)
if err != nil {
return AgentInfo{}, nil
}
return agentInfo, err
}
// Events sent over the websocket connection that is established between
// agent and converge server. This is done as soon as the agent starts.
// First commmunication between agent and Converge Server
// Both exchange their protocol version and if it is incorrect, the session
// is terminated.
func CheckProtocolVersion(role Role, channel GOBChannel) error {
log.Println("ROLE ", role)
switch role {
case Agent:
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
if err != nil {
return err
}
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
if err != nil {
return err
}
if version.Version != PROTOCOL_VERSION {
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
PROTOCOL_VERSION, version.Version)
}
return nil
case ConvergeServer:
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
if err != nil {
return err
}
err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
if err != nil {
return err
}
if version.Version != PROTOCOL_VERSION {
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
PROTOCOL_VERSION, version.Version)
}
return nil
default:
panic(fmt.Errorf("unexpected rolg %v", role))
}
}
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
channel := NewGOBChannel(conn) channel := NewGOBChannel(conn)
sends := make(chan any) sends := make(chan bool)
receives := make(chan any) receives := make(chan ProtocolVersion)
errors := make(chan error) errors := make(chan error)
channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
channel.ReceiveAsync(receives, errors) ReceiveAsync(channel, receives, errors)
select { select {
case <-time.After(10 * time.Second): case <-time.After(MESSAGE_TIMEOUT):
log.Println("PROTOCOLVERSION: timeout") log.Println("PROTOCOLVERSION: timeout")
return fmt.Errorf("Timeout waiting for protocol version") return fmt.Errorf("Timeout waiting for protocol version")
case err := <-errors: case err := <-errors:
log.Printf("PROTOCOLVERSION: %v", err) log.Printf("PROTOCOLVERSION: %v", err)
return err return err
case protocolVersion := <-receives: case protocolVersion := <-receives:
otherVersion := protocolVersion.(ProtocolVersion).Version otherVersion := protocolVersion.Version
if PROTOCOL_VERSION != otherVersion { if PROTOCOL_VERSION != otherVersion {
switch role { switch role {
case Agent: case Agent:
@ -170,7 +250,12 @@ func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
} }
return fmt.Errorf("Protocol version mismatch") return fmt.Errorf("Protocol version mismatch")
} }
log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
return nil return nil
} }
} }
// Session info metadata exchange. These are sent over the SSH connection. The agent embedded
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
// metadata before the connection is handed back to SSH.

View File

@ -23,6 +23,10 @@ type AgentInfo struct {
OS string OS string
} }
type ClientInfo struct {
ClientId string
}
type SessionInfo struct { type SessionInfo struct {
// "ssh", "sftp" // "ssh", "sftp"
SessionType string SessionType string
@ -47,8 +51,7 @@ type UserPassword struct {
Password string Password string
} }
type ConnectionInfo struct { type ServerInfo struct {
ConnectionId int
UserPassword UserPassword UserPassword UserPassword
} }
@ -88,7 +91,6 @@ func RegisterEventsWithGob() {
// ConvergeServer to Agent // ConvergeServer to Agent
gob.Register(ProtocolVersion{}) gob.Register(ProtocolVersion{})
gob.Register(UserPassword{}) gob.Register(UserPassword{})
gob.Register(ConnectionInfo{})
// Wrapper event. // Wrapper event.
gob.Register(ConvergeMessage{}) gob.Register(ConvergeMessage{})

View File

@ -2,8 +2,10 @@ package comms
import ( import (
"encoding/gob" "encoding/gob"
"fmt"
"io" "io"
"log" "log"
"time"
) )
type GOBChannel struct { type GOBChannel struct {
@ -22,12 +24,26 @@ func NewGOBChannel(conn io.ReadWriter) GOBChannel {
} }
} }
func Send(channel GOBChannel, object any) error {
err := channel.Encoder.Encode(object)
if err != nil {
log.Printf("Encoding error %v", err)
}
return err
}
func Receive[T any](channel GOBChannel) (T, error) {
target := *new(T)
err := channel.Decoder.Decode(&target)
return target, err
}
// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of // Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts. // messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) { func SendAsync[T any](channel GOBChannel, obj T, done chan<- bool, errors chan<- error) {
go func() { go func() {
err := channel.Send(obj) err := Send(channel, obj)
if err != nil { if err != nil {
errors <- err errors <- err
} else { } else {
@ -36,9 +52,9 @@ func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- erro
}() }()
} }
func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { func ReceiveAsync[T any](channel GOBChannel, result chan T, errors chan<- error) {
go func() { go func() {
value, err := channel.Receive() value, err := Receive[T](channel)
if err != nil { if err != nil {
errors <- err errors <- err
} else { } else {
@ -47,16 +63,32 @@ func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) {
}() }()
} }
func (channel GOBChannel) Send(object any) error { func SendWithTimeout[T any](channel GOBChannel, obj T) error {
err := channel.Encoder.Encode(ConvergeMessage{Value: object}) done := make(chan bool)
if err != nil { errors := make(chan error)
log.Printf("Encoding error %v", err)
} SendAsync(channel, obj, done, errors)
select {
case <-time.After(MESSAGE_TIMEOUT):
return fmt.Errorf("Timeout in SwndWithTimout")
case err := <-errors:
return err return err
case <-done:
return nil
}
} }
func (channel GOBChannel) Receive() (any, error) { func ReceiveWithTimeout[T any](channel GOBChannel) (T, error) {
var target ConvergeMessage result := make(chan T)
err := channel.Decoder.Decode(&target) errors := make(chan error)
return target.Value, err
ReceiveAsync(channel, result, errors)
select {
case <-time.After(MESSAGE_TIMEOUT):
return *new(T), fmt.Errorf("Timeout in ReceiveWithTimout")
case err := <-errors:
return *new(T), err
case value := <-result:
return value, nil
}
} }

View File

@ -29,11 +29,12 @@ type Client struct {
sessionType string sessionType string
} }
func NewAgent(commChannel comms.CommChannel, publicId string) *Agent { func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent {
return &Agent{ return &Agent{
commChannel: commChannel, commChannel: commChannel,
publicId: publicId, publicId: publicId,
startTime: time.Now(), startTime: time.Now(),
agentInfo: agentInfo,
} }
} }
@ -88,7 +89,7 @@ func (admin *Admin) logStatus() {
log.Printf("\n") log.Printf("\n")
} }
func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, error) { func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) {
admin.mutex.Lock() admin.mutex.Lock()
defer admin.mutex.Unlock() defer admin.mutex.Unlock()
@ -102,7 +103,7 @@ func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent,
if err != nil { if err != nil {
return nil, err return nil, err
} }
agent = NewAgent(commChannel, publicId) agent = NewAgent(commChannel, publicId, agentInfo)
admin.agents[publicId] = agent admin.agents[publicId] = agent
admin.logStatus() admin.logStatus()
return agent, nil return agent, nil
@ -172,13 +173,16 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
userPassword comms.UserPassword) error { userPassword comms.UserPassword) error {
defer conn.Close() defer conn.Close()
err := comms.CheckProtocolVersion(comms.ConvergeServer, conn) serverInfo := comms.ServerInfo{
UserPassword: userPassword,
}
agentInfo, err := comms.ServerInitialization(conn, serverInfo)
if err != nil { if err != nil {
return err return err
} }
// TODO: remove agent return value agent, err := admin.addAgent(publicId, agentInfo, conn)
agent, err := admin.addAgent(publicId, conn)
if err != nil { if err != nil {
return err return err
} }
@ -186,9 +190,6 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
admin.RemoveAgent(publicId) admin.RemoveAgent(publicId)
}() }()
log.Println("Sending username and password to agent")
agent.commChannel.SideChannel.Send(userPassword)
go func() { go func() {
comms.ListenForAgentEvents(agent.commChannel.SideChannel, comms.ListenForAgentEvents(agent.commChannel.SideChannel,
func(info comms.AgentInfo) { func(info comms.AgentInfo) {