diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 31d7ec0..7860403 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -159,8 +159,6 @@ func main() { } wsURL := flag.Arg(0) - agent.ConfigureAgent(*advanceWarningTime, *agentExpriryTime, *tickerInterval) - dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: 45 * time.Second, @@ -226,5 +224,7 @@ func main() { strings.ReplaceAll(urlObject.Scheme, "ws", "http")+ "://"+urlObject.Host+"/docs/wsproxy") log.Println() + + agent.ConfigureAgent(commChannel, *advanceWarningTime, *agentExpriryTime, *tickerInterval) service.Run(commChannel.Session) } diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 600b61d..77f031a 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -3,6 +3,7 @@ package agent import ( "bytes" "converge/pkg/async" + "converge/pkg/comms" "fmt" "github.com/fsnotify/fsnotify" "github.com/gliderlabs/ssh" @@ -31,7 +32,8 @@ import ( // global configuration type AgentState struct { - startTime time.Time + commChannel comms.CommChannel + startTime time.Time // Advance warning time to notify the user of something important happening advanceWarningTime time.Duration @@ -73,7 +75,8 @@ var events = make(chan func(), 10) // External interface, asynchronous, apart from the initialization. -func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { +func ConfigureAgent(commChannel comms.CommChannel, + advanceWarningTime, agentExpiryTime, tickerInterval time.Duration) { if fileExists(holdFilename) { log.Printf("Removing hold file '%s'", holdFilename) err := os.Remove(holdFilename) @@ -82,6 +85,7 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur } } state = AgentState{ + commChannel: commChannel, startTime: time.Now(), advanceWarningTime: advanceWarningTime, agentExpriryTime: agentExpiryTime, @@ -97,6 +101,9 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur log.Printf("Agent expires at %s", state.expiryTime(holdFilename).Format(time.DateTime)) + comms.SendSessionInfo(state.commChannel) + comms.SendExpiryTimeUpdate(state.commChannel, state.expiryTime(holdFilename)) + go func() { for { <-state.ticker.C @@ -284,6 +291,7 @@ func holdFileChange() { message += holdFileMessage() messageUsers(message) state.lastExpiryTimmeReported = newExpiryTIme + comms.SendExpiryTimeUpdate(state.commChannel, state.lastExpiryTimmeReported) } } diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 8ec4f45..ddc5e81 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -71,7 +71,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { } log.Println("Communication channel between agent and converge server established") - gob.Register(RemoteSession{}) + gob.Register(SessionInfo{}) gob.Register(ExpiryTimeUpdate{}) gob.Register(ConvergeMessage{}) @@ -79,15 +79,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { commChannel.Decoder = gob.NewDecoder(commChannel.Peer) switch role { - case Agent: - err := commChannel.Encoder.Encode(ConvergeMessage{Value: NewRemoteSession()}) - if err != nil { - log.Printf("Encoding error %v", err) - } - err = commChannel.Encoder.Encode(ConvergeMessage{Value: NewExpiryTimeUpdate(time.Now())}) - if err != nil { - log.Printf("Encoding error %v", err) - } case ConvergeServer: go serverReader(commChannel) } @@ -101,10 +92,12 @@ func serverReader(channel CommChannel) { err := channel.Decoder.Decode(&result) if err != nil { - log.Printf("Error decoding object %v", err) + // TODO more clean solution, need to explicitly close when agent exits. + log.Printf("Exiting serverReader %v", err) + return } switch v := result.Value.(type) { - case RemoteSession: + case SessionInfo: log.Println("RECEIVED: session info ", v) case ExpiryTimeUpdate: log.Println("RECEIVED: expirytime update ", v) @@ -114,17 +107,17 @@ func serverReader(channel CommChannel) { } } -type RemoteSession struct { +type SessionInfo struct { Username string Hostname string Pwd string } -func NewRemoteSession() RemoteSession { +func NewSessionInfo() SessionInfo { username, _ := user.Current() host, _ := os.Hostname() pwd, _ := os.Getwd() - return RemoteSession{ + return SessionInfo{ Username: username.Username, Hostname: host, Pwd: pwd, diff --git a/pkg/comms/serverapi.go b/pkg/comms/serverapi.go new file mode 100644 index 0000000..e350817 --- /dev/null +++ b/pkg/comms/serverapi.go @@ -0,0 +1,20 @@ +package comms + +import ( + "log" + "time" +) + +func SendSessionInfo(commChannel CommChannel) { + err := commChannel.Encoder.Encode(ConvergeMessage{Value: NewSessionInfo()}) + if err != nil { + log.Printf("Encoding error %v", err) + } +} + +func SendExpiryTimeUpdate(commChannel CommChannel, expiryTime time.Time) { + err := commChannel.Encoder.Encode(ConvergeMessage{Value: NewExpiryTimeUpdate(expiryTime)}) + if err != nil { + log.Printf("Encoding error %v", err) + } +}