initialization of username, password on client (from server) and initialization of agentinfo on server is now done as soon as the agent registered and not through a side channel.
Making use of some simple utilities for GOB to make it easy to send objects over the line.
This commit is contained in:
parent
ada34495ef
commit
9d0675b2f2
@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -33,7 +32,7 @@ import (
|
|||||||
var hostPrivateKey []byte
|
var hostPrivateKey []byte
|
||||||
|
|
||||||
func SftpHandler(sess ssh.Session) {
|
func SftpHandler(sess ssh.Session) {
|
||||||
uid := int(time.Now().UnixMilli())
|
uid := int(time.Now().UnixMicro())
|
||||||
agent.Login(uid, sess)
|
agent.Login(uid, sess)
|
||||||
defer agent.LogOut(uid)
|
defer agent.LogOut(uid)
|
||||||
|
|
||||||
@ -249,8 +248,9 @@ func main() {
|
|||||||
wsConn := websocketutil.NewWebSocketConn(conn)
|
wsConn := websocketutil.NewWebSocketConn(conn)
|
||||||
defer wsConn.Close()
|
defer wsConn.Close()
|
||||||
|
|
||||||
err = comms.CheckProtocolVersion(comms.Agent, wsConn)
|
serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("ERROR: %+v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,7 +261,10 @@ func main() {
|
|||||||
|
|
||||||
// Authentiocation
|
// Authentiocation
|
||||||
|
|
||||||
sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile)
|
passwordHandler, authorizedKeys := setupAuthentication(
|
||||||
|
commChannel,
|
||||||
|
serverInfo.UserPassword,
|
||||||
|
authorizedKeysFile)
|
||||||
|
|
||||||
// Choose shell
|
// Choose shell
|
||||||
|
|
||||||
@ -279,9 +282,9 @@ func main() {
|
|||||||
log.Println()
|
log.Println()
|
||||||
clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/")
|
clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/")
|
||||||
sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
|
sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
|
||||||
clientUrl, sshUserCredentials.Username)
|
clientUrl, serverInfo.UserPassword.Username)
|
||||||
sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
|
sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost",
|
||||||
clientUrl, sshUserCredentials.Username)
|
clientUrl, serverInfo.UserPassword.Username)
|
||||||
log.Println(" # For SSH")
|
log.Println(" # For SSH")
|
||||||
log.Println(" " + sshCommand)
|
log.Println(" " + sshCommand)
|
||||||
log.Println()
|
log.Println()
|
||||||
@ -302,26 +305,20 @@ func main() {
|
|||||||
service.Run(listener)
|
service.Run(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
|
func setupAuthentication(commChannel comms.CommChannel,
|
||||||
// Random user name and password so that effectively no one can login
|
userPassword comms.UserPassword,
|
||||||
// until the user and password have been received from the server.
|
authorizedKeysFile string) (func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
|
||||||
sshUserCredentials := comms.UserPassword{
|
|
||||||
Username: strconv.Itoa(rand.Int()),
|
|
||||||
Password: strconv.Itoa(rand.Int()),
|
|
||||||
}
|
|
||||||
passwordHandler := func(ctx ssh.Context, password string) bool {
|
passwordHandler := func(ctx ssh.Context, password string) bool {
|
||||||
// Replace with your own logic to validate username and password
|
// Replace with your own logic to validate username and password
|
||||||
return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password
|
return ctx.User() == userPassword.Username && password == userPassword.Password
|
||||||
}
|
}
|
||||||
go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) {
|
go comms.ListenForServerEvents(commChannel)
|
||||||
log.Println("Username and password configuration received from server")
|
|
||||||
sshUserCredentials = user
|
|
||||||
})
|
|
||||||
authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile)
|
authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile)
|
||||||
if len(authorizedKeys.keys) > 0 {
|
if len(authorizedKeys.keys) > 0 {
|
||||||
log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys))
|
log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys))
|
||||||
}
|
}
|
||||||
return sshUserCredentials, passwordHandler, authorizedKeys
|
return passwordHandler, authorizedKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
func chooseShell() string {
|
func chooseShell() string {
|
||||||
|
@ -101,8 +101,10 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
|||||||
log.Printf("Agent expires at %s",
|
log.Printf("Agent expires at %s",
|
||||||
state.expiryTime(holdFilename).Format(time.DateTime))
|
state.expiryTime(holdFilename).Format(time.DateTime))
|
||||||
|
|
||||||
state.commChannel.SideChannel.Send(comms.NewAgentInfo())
|
comms.Send(state.commChannel.SideChannel,
|
||||||
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
comms.ConvergeMessage{
|
||||||
|
Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
|
||||||
|
})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@ -182,7 +184,10 @@ func login(sessionId int, sshSession ssh.Session) {
|
|||||||
if sessionType == "" {
|
if sessionType == "" {
|
||||||
sessionType = "ssh"
|
sessionType = "ssh"
|
||||||
}
|
}
|
||||||
state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType))
|
comms.Send(state.commChannel.SideChannel,
|
||||||
|
comms.ConvergeMessage{
|
||||||
|
Value: comms.NewSessionInfo(sessionType),
|
||||||
|
})
|
||||||
|
|
||||||
holdFileStats, ok := fileExistsWithStats(holdFilename)
|
holdFileStats, ok := fileExistsWithStats(holdFilename)
|
||||||
if ok {
|
if ok {
|
||||||
@ -297,7 +302,10 @@ func holdFileChange() {
|
|||||||
message += holdFileMessage()
|
message += holdFileMessage()
|
||||||
messageUsers(message)
|
messageUsers(message)
|
||||||
state.lastExpiryTimmeReported = newExpiryTIme
|
state.lastExpiryTimmeReported = newExpiryTIme
|
||||||
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
comms.Send(state.commChannel.SideChannel,
|
||||||
|
comms.ConvergeMessage{
|
||||||
|
Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MESSAGE_TIMEOUT = 10 * time.Second
|
||||||
|
|
||||||
type CommChannel struct {
|
type CommChannel struct {
|
||||||
// a separet connection outside of the ssh session
|
// a separet connection outside of the ssh session
|
||||||
SideChannel GOBChannel
|
SideChannel GOBChannel
|
||||||
@ -70,7 +72,10 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
err := commChannel.SideChannel.Send(HeartBeat{})
|
err := Send(commChannel.SideChannel,
|
||||||
|
ConvergeMessage{
|
||||||
|
Value: HeartBeat{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Sending heartbeat to server failed")
|
log.Println("Sending heartbeat to server failed")
|
||||||
}
|
}
|
||||||
@ -113,13 +118,12 @@ func ListenForAgentEvents(channel GOBChannel,
|
|||||||
// intended to keep the connection up
|
// intended to keep the connection up
|
||||||
|
|
||||||
default:
|
default:
|
||||||
fmt.Printf(" Unknown type: %T\n", v)
|
fmt.Printf(" Unknown type: %v %T\n", v, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenForServerEvents(channel CommChannel,
|
func ListenForServerEvents(channel CommChannel) {
|
||||||
setUsernamePassword func(user UserPassword)) {
|
|
||||||
for {
|
for {
|
||||||
var result ConvergeMessage
|
var result ConvergeMessage
|
||||||
err := channel.SideChannel.Decoder.Decode(&result)
|
err := channel.SideChannel.Decoder.Decode(&result)
|
||||||
@ -129,10 +133,9 @@ func ListenForServerEvents(channel CommChannel,
|
|||||||
log.Printf("Exiting agent listener %v", err)
|
log.Printf("Exiting agent listener %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch v := result.Value.(type) {
|
|
||||||
|
|
||||||
case UserPassword:
|
// no supported server events at this time.
|
||||||
setUsernamePassword(v)
|
switch v := result.Value.(type) {
|
||||||
|
|
||||||
default:
|
default:
|
||||||
fmt.Printf(" Unknown type: %T\n", v)
|
fmt.Printf(" Unknown type: %T\n", v)
|
||||||
@ -140,25 +143,102 @@ func ListenForServerEvents(channel CommChannel,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
|
func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) {
|
||||||
|
channel := NewGOBChannel(conn)
|
||||||
|
err := CheckProtocolVersion(Agent, channel)
|
||||||
|
|
||||||
|
err = SendWithTimeout(channel, agentInto)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInfo{}, nil
|
||||||
|
}
|
||||||
|
serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInfo{}, nil
|
||||||
|
}
|
||||||
|
// TODO remove logging
|
||||||
|
log.Println("Server info received: ", serverInfo)
|
||||||
|
|
||||||
|
return serverInfo, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) {
|
||||||
|
channel := NewGOBChannel(conn)
|
||||||
|
err := CheckProtocolVersion(ConvergeServer, channel)
|
||||||
|
|
||||||
|
agentInfo, err := ReceiveWithTimeout[AgentInfo](channel)
|
||||||
|
if err != nil {
|
||||||
|
return AgentInfo{}, err
|
||||||
|
}
|
||||||
|
log.Println("Agent info received: ", agentInfo)
|
||||||
|
err = SendWithTimeout(channel, serverInfo)
|
||||||
|
if err != nil {
|
||||||
|
return AgentInfo{}, nil
|
||||||
|
}
|
||||||
|
return agentInfo, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events sent over the websocket connection that is established between
|
||||||
|
// agent and converge server. This is done as soon as the agent starts.
|
||||||
|
|
||||||
|
// First commmunication between agent and Converge Server
|
||||||
|
// Both exchange their protocol version and if it is incorrect, the session
|
||||||
|
// is terminated.
|
||||||
|
|
||||||
|
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
||||||
|
log.Println("ROLE ", role)
|
||||||
|
switch role {
|
||||||
|
case Agent:
|
||||||
|
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if version.Version != PROTOCOL_VERSION {
|
||||||
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
||||||
|
PROTOCOL_VERSION, version.Version)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case ConvergeServer:
|
||||||
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if version.Version != PROTOCOL_VERSION {
|
||||||
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
||||||
|
PROTOCOL_VERSION, version.Version)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("unexpected rolg %v", role))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
|
||||||
channel := NewGOBChannel(conn)
|
channel := NewGOBChannel(conn)
|
||||||
|
|
||||||
sends := make(chan any)
|
sends := make(chan bool)
|
||||||
receives := make(chan any)
|
receives := make(chan ProtocolVersion)
|
||||||
errors := make(chan error)
|
errors := make(chan error)
|
||||||
|
|
||||||
channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
||||||
channel.ReceiveAsync(receives, errors)
|
ReceiveAsync(channel, receives, errors)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-time.After(10 * time.Second):
|
case <-time.After(MESSAGE_TIMEOUT):
|
||||||
log.Println("PROTOCOLVERSION: timeout")
|
log.Println("PROTOCOLVERSION: timeout")
|
||||||
return fmt.Errorf("Timeout waiting for protocol version")
|
return fmt.Errorf("Timeout waiting for protocol version")
|
||||||
case err := <-errors:
|
case err := <-errors:
|
||||||
log.Printf("PROTOCOLVERSION: %v", err)
|
log.Printf("PROTOCOLVERSION: %v", err)
|
||||||
return err
|
return err
|
||||||
case protocolVersion := <-receives:
|
case protocolVersion := <-receives:
|
||||||
otherVersion := protocolVersion.(ProtocolVersion).Version
|
otherVersion := protocolVersion.Version
|
||||||
if PROTOCOL_VERSION != otherVersion {
|
if PROTOCOL_VERSION != otherVersion {
|
||||||
switch role {
|
switch role {
|
||||||
case Agent:
|
case Agent:
|
||||||
@ -170,7 +250,12 @@ func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("Protocol version mismatch")
|
return fmt.Errorf("Protocol version mismatch")
|
||||||
}
|
}
|
||||||
log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version)
|
log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Session info metadata exchange. These are sent over the SSH connection. The agent embedded
|
||||||
|
// ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
|
||||||
|
// decorates the yamux Session (which is a listener) and uses this connection to exchange some
|
||||||
|
// metadata before the connection is handed back to SSH.
|
||||||
|
@ -23,6 +23,10 @@ type AgentInfo struct {
|
|||||||
OS string
|
OS string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClientInfo struct {
|
||||||
|
ClientId string
|
||||||
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
// "ssh", "sftp"
|
// "ssh", "sftp"
|
||||||
SessionType string
|
SessionType string
|
||||||
@ -47,8 +51,7 @@ type UserPassword struct {
|
|||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnectionInfo struct {
|
type ServerInfo struct {
|
||||||
ConnectionId int
|
|
||||||
UserPassword UserPassword
|
UserPassword UserPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,7 +91,6 @@ func RegisterEventsWithGob() {
|
|||||||
// ConvergeServer to Agent
|
// ConvergeServer to Agent
|
||||||
gob.Register(ProtocolVersion{})
|
gob.Register(ProtocolVersion{})
|
||||||
gob.Register(UserPassword{})
|
gob.Register(UserPassword{})
|
||||||
gob.Register(ConnectionInfo{})
|
|
||||||
|
|
||||||
// Wrapper event.
|
// Wrapper event.
|
||||||
gob.Register(ConvergeMessage{})
|
gob.Register(ConvergeMessage{})
|
||||||
|
@ -2,8 +2,10 @@ package comms
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GOBChannel struct {
|
type GOBChannel struct {
|
||||||
@ -22,12 +24,26 @@ func NewGOBChannel(conn io.ReadWriter) GOBChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Send(channel GOBChannel, object any) error {
|
||||||
|
err := channel.Encoder.Encode(object)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Encoding error %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Receive[T any](channel GOBChannel) (T, error) {
|
||||||
|
target := *new(T)
|
||||||
|
err := channel.Decoder.Decode(&target)
|
||||||
|
return target, err
|
||||||
|
}
|
||||||
|
|
||||||
// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
|
// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
|
||||||
// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
|
// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
|
||||||
|
|
||||||
func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) {
|
func SendAsync[T any](channel GOBChannel, obj T, done chan<- bool, errors chan<- error) {
|
||||||
go func() {
|
go func() {
|
||||||
err := channel.Send(obj)
|
err := Send(channel, obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- err
|
errors <- err
|
||||||
} else {
|
} else {
|
||||||
@ -36,9 +52,9 @@ func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- erro
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) {
|
func ReceiveAsync[T any](channel GOBChannel, result chan T, errors chan<- error) {
|
||||||
go func() {
|
go func() {
|
||||||
value, err := channel.Receive()
|
value, err := Receive[T](channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- err
|
errors <- err
|
||||||
} else {
|
} else {
|
||||||
@ -47,16 +63,32 @@ func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel GOBChannel) Send(object any) error {
|
func SendWithTimeout[T any](channel GOBChannel, obj T) error {
|
||||||
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
|
done := make(chan bool)
|
||||||
if err != nil {
|
errors := make(chan error)
|
||||||
log.Printf("Encoding error %v", err)
|
|
||||||
|
SendAsync(channel, obj, done, errors)
|
||||||
|
select {
|
||||||
|
case <-time.After(MESSAGE_TIMEOUT):
|
||||||
|
return fmt.Errorf("Timeout in SwndWithTimout")
|
||||||
|
case err := <-errors:
|
||||||
|
return err
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel GOBChannel) Receive() (any, error) {
|
func ReceiveWithTimeout[T any](channel GOBChannel) (T, error) {
|
||||||
var target ConvergeMessage
|
result := make(chan T)
|
||||||
err := channel.Decoder.Decode(&target)
|
errors := make(chan error)
|
||||||
return target.Value, err
|
|
||||||
|
ReceiveAsync(channel, result, errors)
|
||||||
|
select {
|
||||||
|
case <-time.After(MESSAGE_TIMEOUT):
|
||||||
|
return *new(T), fmt.Errorf("Timeout in ReceiveWithTimout")
|
||||||
|
case err := <-errors:
|
||||||
|
return *new(T), err
|
||||||
|
case value := <-result:
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,11 +29,12 @@ type Client struct {
|
|||||||
sessionType string
|
sessionType string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent(commChannel comms.CommChannel, publicId string) *Agent {
|
func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent {
|
||||||
return &Agent{
|
return &Agent{
|
||||||
commChannel: commChannel,
|
commChannel: commChannel,
|
||||||
publicId: publicId,
|
publicId: publicId,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
|
agentInfo: agentInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,7 +89,7 @@ func (admin *Admin) logStatus() {
|
|||||||
log.Printf("\n")
|
log.Printf("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, error) {
|
func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) {
|
||||||
admin.mutex.Lock()
|
admin.mutex.Lock()
|
||||||
defer admin.mutex.Unlock()
|
defer admin.mutex.Unlock()
|
||||||
|
|
||||||
@ -102,7 +103,7 @@ func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
agent = NewAgent(commChannel, publicId)
|
agent = NewAgent(commChannel, publicId, agentInfo)
|
||||||
admin.agents[publicId] = agent
|
admin.agents[publicId] = agent
|
||||||
admin.logStatus()
|
admin.logStatus()
|
||||||
return agent, nil
|
return agent, nil
|
||||||
@ -172,13 +173,16 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
|
|||||||
userPassword comms.UserPassword) error {
|
userPassword comms.UserPassword) error {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err := comms.CheckProtocolVersion(comms.ConvergeServer, conn)
|
serverInfo := comms.ServerInfo{
|
||||||
|
UserPassword: userPassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
agentInfo, err := comms.ServerInitialization(conn, serverInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove agent return value
|
agent, err := admin.addAgent(publicId, agentInfo, conn)
|
||||||
agent, err := admin.addAgent(publicId, conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -186,9 +190,6 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
|
|||||||
admin.RemoveAgent(publicId)
|
admin.RemoveAgent(publicId)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Println("Sending username and password to agent")
|
|
||||||
agent.commChannel.SideChannel.Send(userPassword)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
comms.ListenForAgentEvents(agent.commChannel.SideChannel,
|
comms.ListenForAgentEvents(agent.commChannel.SideChannel,
|
||||||
func(info comms.AgentInfo) {
|
func(info comms.AgentInfo) {
|
||||||
|
Loading…
Reference in New Issue
Block a user