From 9d0675b2f2cd0ca0ba10e49e7ac0fd81cf03fe94 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Sat, 27 Jul 2024 20:46:53 +0200 Subject: [PATCH] initialization of username, password on client (from server) and initialization of agentinfo on server is now done as soon as the agent registered and not through a side channel. Making use of some simple utilities for GOB to make it easy to send objects over the line. --- cmd/agent/agent.go | 35 ++++++------ pkg/agent/session.go | 16 ++++-- pkg/comms/agentserver.go | 115 ++++++++++++++++++++++++++++++++++----- pkg/comms/events.go | 8 ++- pkg/comms/gobchannel.go | 58 +++++++++++++++----- pkg/converge/admin.go | 19 ++++--- 6 files changed, 188 insertions(+), 63 deletions(-) diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 00bef69..53b354a 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -14,7 +14,6 @@ import ( "github.com/pkg/sftp" "io" "log" - "math/rand" "net" "net/http" "net/url" @@ -33,7 +32,7 @@ import ( var hostPrivateKey []byte func SftpHandler(sess ssh.Session) { - uid := int(time.Now().UnixMilli()) + uid := int(time.Now().UnixMicro()) agent.Login(uid, sess) defer agent.LogOut(uid) @@ -249,8 +248,9 @@ func main() { wsConn := websocketutil.NewWebSocketConn(conn) defer wsConn.Close() - err = comms.CheckProtocolVersion(comms.Agent, wsConn) + serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo()) if err != nil { + log.Printf("ERROR: %+v", err) os.Exit(1) } @@ -261,7 +261,10 @@ func main() { // Authentiocation - sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile) + passwordHandler, authorizedKeys := setupAuthentication( + commChannel, + serverInfo.UserPassword, + authorizedKeysFile) // Choose shell @@ -279,9 +282,9 @@ func main() { log.Println() clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/") sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", - clientUrl, sshUserCredentials.Username) + clientUrl, serverInfo.UserPassword.Username) sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", - clientUrl, sshUserCredentials.Username) + clientUrl, serverInfo.UserPassword.Username) log.Println(" # For SSH") log.Println(" " + sshCommand) log.Println() @@ -302,26 +305,20 @@ func main() { service.Run(listener) } -func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { - // Random user name and password so that effectively no one can login - // until the user and password have been received from the server. - sshUserCredentials := comms.UserPassword{ - Username: strconv.Itoa(rand.Int()), - Password: strconv.Itoa(rand.Int()), - } +func setupAuthentication(commChannel comms.CommChannel, + userPassword comms.UserPassword, + authorizedKeysFile string) (func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { + passwordHandler := func(ctx ssh.Context, password string) bool { // Replace with your own logic to validate username and password - return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password + return ctx.User() == userPassword.Username && password == userPassword.Password } - go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { - log.Println("Username and password configuration received from server") - sshUserCredentials = user - }) + go comms.ListenForServerEvents(commChannel) authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) if len(authorizedKeys.keys) > 0 { log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) } - return sshUserCredentials, passwordHandler, authorizedKeys + return passwordHandler, authorizedKeys } func chooseShell() string { diff --git a/pkg/agent/session.go b/pkg/agent/session.go index f195308..2a6cdcf 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -101,8 +101,10 @@ func ConfigureAgent(commChannel comms.CommChannel, log.Printf("Agent expires at %s", state.expiryTime(holdFilename).Format(time.DateTime)) - state.commChannel.SideChannel.Send(comms.NewAgentInfo()) - state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + comms.Send(state.commChannel.SideChannel, + comms.ConvergeMessage{ + Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)), + }) go func() { for { @@ -182,7 +184,10 @@ func login(sessionId int, sshSession ssh.Session) { if sessionType == "" { sessionType = "ssh" } - state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType)) + comms.Send(state.commChannel.SideChannel, + comms.ConvergeMessage{ + Value: comms.NewSessionInfo(sessionType), + }) holdFileStats, ok := fileExistsWithStats(holdFilename) if ok { @@ -297,7 +302,10 @@ func holdFileChange() { message += holdFileMessage() messageUsers(message) state.lastExpiryTimmeReported = newExpiryTIme - state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + comms.Send(state.commChannel.SideChannel, + comms.ConvergeMessage{ + Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)), + }) } } diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 3ad5784..593b447 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -8,6 +8,8 @@ import ( "time" ) +const MESSAGE_TIMEOUT = 10 * time.Second + type CommChannel struct { // a separet connection outside of the ssh session SideChannel GOBChannel @@ -70,7 +72,10 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { go func() { for { time.Sleep(10 * time.Second) - err := commChannel.SideChannel.Send(HeartBeat{}) + err := Send(commChannel.SideChannel, + ConvergeMessage{ + Value: HeartBeat{}, + }) if err != nil { log.Println("Sending heartbeat to server failed") } @@ -113,13 +118,12 @@ func ListenForAgentEvents(channel GOBChannel, // intended to keep the connection up default: - fmt.Printf(" Unknown type: %T\n", v) + fmt.Printf(" Unknown type: %v %T\n", v, v) } } } -func ListenForServerEvents(channel CommChannel, - setUsernamePassword func(user UserPassword)) { +func ListenForServerEvents(channel CommChannel) { for { var result ConvergeMessage err := channel.SideChannel.Decoder.Decode(&result) @@ -129,10 +133,9 @@ func ListenForServerEvents(channel CommChannel, log.Printf("Exiting agent listener %v", err) return } - switch v := result.Value.(type) { - case UserPassword: - setUsernamePassword(v) + // no supported server events at this time. + switch v := result.Value.(type) { default: fmt.Printf(" Unknown type: %T\n", v) @@ -140,25 +143,102 @@ func ListenForServerEvents(channel CommChannel, } } -func CheckProtocolVersion(role Role, conn io.ReadWriter) error { +func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) { + channel := NewGOBChannel(conn) + err := CheckProtocolVersion(Agent, channel) + + err = SendWithTimeout(channel, agentInto) + if err != nil { + return ServerInfo{}, nil + } + serverInfo, err := ReceiveWithTimeout[ServerInfo](channel) + if err != nil { + return ServerInfo{}, nil + } + // TODO remove logging + log.Println("Server info received: ", serverInfo) + + return serverInfo, err +} + +func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) { + channel := NewGOBChannel(conn) + err := CheckProtocolVersion(ConvergeServer, channel) + + agentInfo, err := ReceiveWithTimeout[AgentInfo](channel) + if err != nil { + return AgentInfo{}, err + } + log.Println("Agent info received: ", agentInfo) + err = SendWithTimeout(channel, serverInfo) + if err != nil { + return AgentInfo{}, nil + } + return agentInfo, err +} + +// Events sent over the websocket connection that is established between +// agent and converge server. This is done as soon as the agent starts. + +// First commmunication between agent and Converge Server +// Both exchange their protocol version and if it is incorrect, the session +// is terminated. + +func CheckProtocolVersion(role Role, channel GOBChannel) error { + log.Println("ROLE ", role) + switch role { + case Agent: + err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) + if err != nil { + return err + } + version, err := ReceiveWithTimeout[ProtocolVersion](channel) + if err != nil { + return err + } + if version.Version != PROTOCOL_VERSION { + return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", + PROTOCOL_VERSION, version.Version) + } + return nil + case ConvergeServer: + version, err := ReceiveWithTimeout[ProtocolVersion](channel) + if err != nil { + return err + } + err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) + if err != nil { + return err + } + if version.Version != PROTOCOL_VERSION { + return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", + PROTOCOL_VERSION, version.Version) + } + return nil + default: + panic(fmt.Errorf("unexpected rolg %v", role)) + } +} + +func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error { channel := NewGOBChannel(conn) - sends := make(chan any) - receives := make(chan any) + sends := make(chan bool) + receives := make(chan ProtocolVersion) errors := make(chan error) - channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) - channel.ReceiveAsync(receives, errors) + SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) + ReceiveAsync(channel, receives, errors) select { - case <-time.After(10 * time.Second): + case <-time.After(MESSAGE_TIMEOUT): log.Println("PROTOCOLVERSION: timeout") return fmt.Errorf("Timeout waiting for protocol version") case err := <-errors: log.Printf("PROTOCOLVERSION: %v", err) return err case protocolVersion := <-receives: - otherVersion := protocolVersion.(ProtocolVersion).Version + otherVersion := protocolVersion.Version if PROTOCOL_VERSION != otherVersion { switch role { case Agent: @@ -170,7 +250,12 @@ func CheckProtocolVersion(role Role, conn io.ReadWriter) error { } return fmt.Errorf("Protocol version mismatch") } - log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) + log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version) return nil } } + +// Session info metadata exchange. These are sent over the SSH connection. The agent embedded +// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that +// decorates the yamux Session (which is a listener) and uses this connection to exchange some +// metadata before the connection is handed back to SSH. diff --git a/pkg/comms/events.go b/pkg/comms/events.go index 3147945..dfcc222 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -23,6 +23,10 @@ type AgentInfo struct { OS string } +type ClientInfo struct { + ClientId string +} + type SessionInfo struct { // "ssh", "sftp" SessionType string @@ -47,8 +51,7 @@ type UserPassword struct { Password string } -type ConnectionInfo struct { - ConnectionId int +type ServerInfo struct { UserPassword UserPassword } @@ -88,7 +91,6 @@ func RegisterEventsWithGob() { // ConvergeServer to Agent gob.Register(ProtocolVersion{}) gob.Register(UserPassword{}) - gob.Register(ConnectionInfo{}) // Wrapper event. gob.Register(ConvergeMessage{}) diff --git a/pkg/comms/gobchannel.go b/pkg/comms/gobchannel.go index d5ed654..854112c 100644 --- a/pkg/comms/gobchannel.go +++ b/pkg/comms/gobchannel.go @@ -2,8 +2,10 @@ package comms import ( "encoding/gob" + "fmt" "io" "log" + "time" ) type GOBChannel struct { @@ -22,12 +24,26 @@ func NewGOBChannel(conn io.ReadWriter) GOBChannel { } } +func Send(channel GOBChannel, object any) error { + err := channel.Encoder.Encode(object) + if err != nil { + log.Printf("Encoding error %v", err) + } + return err +} + +func Receive[T any](channel GOBChannel) (T, error) { + target := *new(T) + err := channel.Decoder.Decode(&target) + return target, err +} + // Asynchronous send and receive on a single connection is guaranteed to preserver ordering of // messages. We use asynchronous to void blocking indefinitely or depending on network timeouts. -func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) { +func SendAsync[T any](channel GOBChannel, obj T, done chan<- bool, errors chan<- error) { go func() { - err := channel.Send(obj) + err := Send(channel, obj) if err != nil { errors <- err } else { @@ -36,9 +52,9 @@ func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- erro }() } -func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { +func ReceiveAsync[T any](channel GOBChannel, result chan T, errors chan<- error) { go func() { - value, err := channel.Receive() + value, err := Receive[T](channel) if err != nil { errors <- err } else { @@ -47,16 +63,32 @@ func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { }() } -func (channel GOBChannel) Send(object any) error { - err := channel.Encoder.Encode(ConvergeMessage{Value: object}) - if err != nil { - log.Printf("Encoding error %v", err) +func SendWithTimeout[T any](channel GOBChannel, obj T) error { + done := make(chan bool) + errors := make(chan error) + + SendAsync(channel, obj, done, errors) + select { + case <-time.After(MESSAGE_TIMEOUT): + return fmt.Errorf("Timeout in SwndWithTimout") + case err := <-errors: + return err + case <-done: + return nil } - return err } -func (channel GOBChannel) Receive() (any, error) { - var target ConvergeMessage - err := channel.Decoder.Decode(&target) - return target.Value, err +func ReceiveWithTimeout[T any](channel GOBChannel) (T, error) { + result := make(chan T) + errors := make(chan error) + + ReceiveAsync(channel, result, errors) + select { + case <-time.After(MESSAGE_TIMEOUT): + return *new(T), fmt.Errorf("Timeout in ReceiveWithTimout") + case err := <-errors: + return *new(T), err + case value := <-result: + return value, nil + } } diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index 1489377..dce80b6 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -29,11 +29,12 @@ type Client struct { sessionType string } -func NewAgent(commChannel comms.CommChannel, publicId string) *Agent { +func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent { return &Agent{ commChannel: commChannel, publicId: publicId, startTime: time.Now(), + agentInfo: agentInfo, } } @@ -88,7 +89,7 @@ func (admin *Admin) logStatus() { log.Printf("\n") } -func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, error) { +func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) { admin.mutex.Lock() defer admin.mutex.Unlock() @@ -102,7 +103,7 @@ func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, if err != nil { return nil, err } - agent = NewAgent(commChannel, publicId) + agent = NewAgent(commChannel, publicId, agentInfo) admin.agents[publicId] = agent admin.logStatus() return agent, nil @@ -172,13 +173,16 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, userPassword comms.UserPassword) error { defer conn.Close() - err := comms.CheckProtocolVersion(comms.ConvergeServer, conn) + serverInfo := comms.ServerInfo{ + UserPassword: userPassword, + } + + agentInfo, err := comms.ServerInitialization(conn, serverInfo) if err != nil { return err } - // TODO: remove agent return value - agent, err := admin.addAgent(publicId, conn) + agent, err := admin.addAgent(publicId, agentInfo, conn) if err != nil { return err } @@ -186,9 +190,6 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, admin.RemoveAgent(publicId) }() - log.Println("Sending username and password to agent") - agent.commChannel.SideChannel.Send(userPassword) - go func() { comms.ListenForAgentEvents(agent.commChannel.SideChannel, func(info comms.AgentInfo) {