Uses some simple GOB utilities to make it easy to send objects over the wire.
262 lines
6.5 KiB
Go
262 lines
6.5 KiB
Go
package comms
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/hashicorp/yamux"
|
|
"io"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
// MESSAGE_TIMEOUT is how long a single protocol message may take to arrive.
// In this file it bounds the wait in CheckProtocolVersionOld; presumably the
// *WithTimeout helpers defined elsewhere use it too — confirm.
const MESSAGE_TIMEOUT time.Duration = 10 * time.Second
|
|
|
|
type CommChannel struct {
|
|
// a separet connection outside of the ssh session
|
|
SideChannel GOBChannel
|
|
Session *yamux.Session
|
|
}
|
|
|
|
// Role says which end of the agent <-> Converge server connection the local
// process is. It selects the yamux side in NewCommChannel and the message
// order of the handshake in CheckProtocolVersion.
type Role int

// Supported roles.
const (
	// Agent is the agent end of the connection.
	Agent Role = 0
	// ConvergeServer is the Converge server end of the connection.
	ConvergeServer Role = 1
)
|
|
|
|
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|
var commChannel CommChannel
|
|
switch role {
|
|
case Agent:
|
|
listener, err := yamux.Server(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: listener,
|
|
}
|
|
case ConvergeServer:
|
|
clientSession, err := yamux.Client(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: clientSession,
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
|
|
// communication between Agent and ConvergeServer
|
|
// Currently used only fof communication from Agent to ConvergeServer
|
|
|
|
switch role {
|
|
case Agent:
|
|
conn, err := commChannel.Session.OpenStream()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
case ConvergeServer:
|
|
conn, err := commChannel.Session.Accept()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
log.Println("Communication channel between agent and converge server established")
|
|
|
|
// heartbeat
|
|
if role == Agent {
|
|
go func() {
|
|
for {
|
|
time.Sleep(10 * time.Second)
|
|
err := Send(commChannel.SideChannel,
|
|
ConvergeMessage{
|
|
Value: HeartBeat{},
|
|
})
|
|
if err != nil {
|
|
log.Println("Sending heartbeat to server failed")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
return commChannel, nil
|
|
}
|
|
|
|
// Sending an event to the other side
|
|
|
|
func ListenForAgentEvents(channel GOBChannel,
|
|
agentInfo func(agent AgentInfo),
|
|
sessionInfo func(session SessionInfo),
|
|
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
// TODO more clean solution, need to explicitly close when agent exits.
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
switch v := result.Value.(type) {
|
|
|
|
case AgentInfo:
|
|
agentInfo(v)
|
|
|
|
case SessionInfo:
|
|
sessionInfo(v)
|
|
|
|
case ExpiryTimeUpdate:
|
|
expiryTimeUpdate(v)
|
|
|
|
case HeartBeat:
|
|
// for not ignoring, can also implement behavior
|
|
// when heartbeat not received but hearbeat is only
|
|
// intended to keep the connection up
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %v %T\n", v, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ListenForServerEvents(channel CommChannel) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.SideChannel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
// TODO more clean solution, need to explicitly close when agent exits.
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
|
|
// no supported server events at this time.
|
|
switch v := result.Value.(type) {
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %T\n", v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) {
|
|
channel := NewGOBChannel(conn)
|
|
err := CheckProtocolVersion(Agent, channel)
|
|
|
|
err = SendWithTimeout(channel, agentInto)
|
|
if err != nil {
|
|
return ServerInfo{}, nil
|
|
}
|
|
serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
|
|
if err != nil {
|
|
return ServerInfo{}, nil
|
|
}
|
|
// TODO remove logging
|
|
log.Println("Server info received: ", serverInfo)
|
|
|
|
return serverInfo, err
|
|
}
|
|
|
|
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) {
|
|
channel := NewGOBChannel(conn)
|
|
err := CheckProtocolVersion(ConvergeServer, channel)
|
|
|
|
agentInfo, err := ReceiveWithTimeout[AgentInfo](channel)
|
|
if err != nil {
|
|
return AgentInfo{}, err
|
|
}
|
|
log.Println("Agent info received: ", agentInfo)
|
|
err = SendWithTimeout(channel, serverInfo)
|
|
if err != nil {
|
|
return AgentInfo{}, nil
|
|
}
|
|
return agentInfo, err
|
|
}
|
|
|
|
// Messages exchanged over the websocket connection that is established
// between the agent and the Converge server as soon as the agent starts.
|
|
|
|
// First commmunication between agent and Converge Server
|
|
// Both exchange their protocol version and if it is incorrect, the session
|
|
// is terminated.
|
|
|
|
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
|
log.Println("ROLE ", role)
|
|
switch role {
|
|
case Agent:
|
|
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if version.Version != PROTOCOL_VERSION {
|
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, version.Version)
|
|
}
|
|
return nil
|
|
case ConvergeServer:
|
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if version.Version != PROTOCOL_VERSION {
|
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, version.Version)
|
|
}
|
|
return nil
|
|
default:
|
|
panic(fmt.Errorf("unexpected rolg %v", role))
|
|
}
|
|
}
|
|
|
|
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
|
|
channel := NewGOBChannel(conn)
|
|
|
|
sends := make(chan bool)
|
|
receives := make(chan ProtocolVersion)
|
|
errors := make(chan error)
|
|
|
|
SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
|
ReceiveAsync(channel, receives, errors)
|
|
|
|
select {
|
|
case <-time.After(MESSAGE_TIMEOUT):
|
|
log.Println("PROTOCOLVERSION: timeout")
|
|
return fmt.Errorf("Timeout waiting for protocol version")
|
|
case err := <-errors:
|
|
log.Printf("PROTOCOLVERSION: %v", err)
|
|
return err
|
|
case protocolVersion := <-receives:
|
|
otherVersion := protocolVersion.Version
|
|
if PROTOCOL_VERSION != otherVersion {
|
|
switch role {
|
|
case Agent:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, otherVersion)
|
|
case ConvergeServer:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
otherVersion, PROTOCOL_VERSION)
|
|
}
|
|
return fmt.Errorf("Protocol version mismatch")
|
|
}
|
|
log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Session info metadata exchange. These messages are sent over the SSH connection. The agent's
// embedded SSH server listens for connections, but we provide a custom listener (AgentListener)
// that decorates the yamux Session (which is a listener) and uses this connection to exchange
// some metadata before the connection is handed back to SSH.
|