converge/pkg/comms/agentserver.go
Erik Brakkee 78e3556787 reintroduced ClientInfo because it does appear to work.
Most likely some error elsewhere caused it not to work previously
2024-09-08 11:16:49 +02:00

258 lines
6.4 KiB
Go

package comms
import (
"fmt"
"github.com/hashicorp/yamux"
"io"
"log"
"time"
)
// MESSAGE_TIMEOUT is a 10 second timeout for a single message exchange.
// NOTE(review): presumably consumed by SendWithTimeout/ReceiveWithTimeout,
// which are defined elsewhere in this package — confirm.
const MESSAGE_TIMEOUT time.Duration = 10 * time.Second
// CommChannel bundles the yamux session that multiplexes the underlying
// agent <-> converge server connection (wsConn in NewCommChannel) with the
// side channel opened on that session.
type CommChannel struct {
	// a separate connection outside of the ssh session; a GOB-encoded
	// stream on Session used for agent <-> converge server messages
	SideChannel GOBChannel
	// yamux session: created with yamux.Server on the agent and with
	// yamux.Client on the converge server (see NewCommChannel)
	Session *yamux.Session
}
// Role tells the communication helpers which end of the agent <-> converge
// server connection the caller is on.
type Role int

// The two supported roles. Any other Role value makes NewCommChannel and
// CheckProtocolVersion panic.
const (
	// Agent is the agent end of the connection.
	Agent Role = 0
	// ConvergeServer is the converge server end of the connection.
	ConvergeServer Role = 1
)
// NewCommChannel multiplexes wsConn with yamux and establishes the side
// channel used for communication between agent and converge server
// (currently only used for communication from agent to converge server).
//
// The two ends must be called with complementary roles:
//   - Agent creates a yamux server session and opens the side-channel stream.
//   - ConvergeServer creates a yamux client session and accepts that stream.
//
// If opening/accepting the side-channel stream fails, the freshly created
// session is closed before returning the error so it does not leak.
// NOTE(review): in yamux, closing a session also closes the underlying
// connection (wsConn here).
//
// Panics if role is neither Agent nor ConvergeServer.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var commChannel CommChannel
	switch role {
	case Agent:
		session, err := yamux.Server(wsConn, nil)
		if err != nil {
			return CommChannel{}, err
		}
		conn, err := session.OpenStream()
		if err != nil {
			// Best effort cleanup; the OpenStream error is the one to report.
			_ = session.Close()
			return CommChannel{}, err
		}
		commChannel = CommChannel{
			Session:     session,
			SideChannel: NewGOBChannel(conn),
		}
	case ConvergeServer:
		session, err := yamux.Client(wsConn, nil)
		if err != nil {
			return CommChannel{}, err
		}
		conn, err := session.Accept()
		if err != nil {
			// Best effort cleanup; the Accept error is the one to report.
			_ = session.Close()
			return CommChannel{}, err
		}
		commChannel = CommChannel{
			Session:     session,
			SideChannel: NewGOBChannel(conn),
		}
	default:
		panic(fmt.Errorf("Undefined role %d", role))
	}
	log.Println("Communication channel between agent and converge server established")
	return commChannel, nil
}
// SetupHeartBeat starts a background goroutine that sends a HeartBeat message
// over commChannel.SideChannel every 10 seconds (communication from agent to
// server during the session).
//
// A failed send is logged and retried on the next tick. The goroutine exits
// once commChannel.Session is closed, so it does not outlive the connection.
// If Session is nil there is nothing to watch and the goroutine runs until
// process exit.
func SetupHeartBeat(commChannel CommChannel) {
	// A nil channel never becomes ready, which keeps the loop running forever
	// when there is no session to observe.
	var done <-chan struct{}
	if commChannel.Session != nil {
		done = commChannel.Session.CloseChan()
	}
	go func() {
		ticker := time.NewTicker(10 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
			}
			err := Send(commChannel.SideChannel,
				ConvergeMessage{
					Value: HeartBeat{},
				})
			if err != nil {
				log.Printf("Sending heartbeat to server failed: %v", err)
			}
		}
	}()
}
// ListenForAgentEvents decodes ConvergeMessage values from channel in a loop
// and dispatches each payload to the matching callback:
//   - EnvironmentInfo  -> agentInfo
//   - SessionInfo      -> sessionInfo
//   - ExpiryTimeUpdate -> expiryTimeUpdate
//   - HeartBeat        -> heartBeat
//
// Payloads of any other type are printed to stdout and skipped. The function
// returns (after logging) on the first decoding error.
func ListenForAgentEvents(channel GOBChannel,
	agentInfo func(agent EnvironmentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate),
	heartBeat func(heartbeat HeartBeat)) {
	for {
		var msg ConvergeMessage
		if decodeErr := channel.Decoder.Decode(&msg); decodeErr != nil {
			log.Printf("Exiting agent listener %v", decodeErr)
			return
		}

		switch payload := msg.Value.(type) {
		case EnvironmentInfo:
			agentInfo(payload)
		case SessionInfo:
			sessionInfo(payload)
		case ExpiryTimeUpdate:
			expiryTimeUpdate(payload)
		case HeartBeat:
			// Heartbeats exist only to keep the connection up; the callback
			// is still invoked so the caller may act on them.
			heartBeat(payload)
		default:
			fmt.Printf(" Unknown type: %v %T\n", payload, payload)
		}
	}
}
// ListenForServerEvents decodes ConvergeMessage values from the side channel
// of channel until a decoding error occurs, which is logged before returning.
//
// No server events are supported at this time, so the dynamic type of every
// received payload is simply printed to stdout as unknown.
func ListenForServerEvents(channel CommChannel) {
	for {
		var msg ConvergeMessage
		if decodeErr := channel.SideChannel.Decoder.Decode(&msg); decodeErr != nil {
			log.Printf("Exiting agent listener %v", decodeErr)
			return
		}
		fmt.Printf(" Unknown type: %T\n", msg.Value)
	}
}
// AgentInitialization performs the agent side of the initial exchange on conn:
//  1. check the protocol version (see CheckProtocolVersion),
//  2. send agentInfo to the converge server,
//  3. wait for the server's ServerInfo reply.
//
// Any failure is returned as an error together with a zero ServerInfo; the
// error is never swallowed, so callers can rely on a nil error meaning the
// ServerInfo was actually received.
func AgentInitialization(conn io.ReadWriter, agentInfo EnvironmentInfo) (ServerInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(Agent, channel); err != nil {
		return ServerInfo{}, err
	}
	if err := SendWithTimeout(channel, agentInfo); err != nil {
		return ServerInfo{}, err
	}
	serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
	if err != nil {
		return ServerInfo{}, err
	}
	log.Println("Agent configuration received from server")
	return serverInfo, nil
}
// ServerInitialization performs the converge server side of the initial
// exchange on conn:
//  1. check the protocol version (see CheckProtocolVersion),
//  2. receive the agent's EnvironmentInfo,
//  3. reply with serverInfo.
//
// Any failure, including a failure to send serverInfo, is returned as an
// error together with a zero EnvironmentInfo.
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (EnvironmentInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(ConvergeServer, channel); err != nil {
		return EnvironmentInfo{}, err
	}
	agentInfo, err := ReceiveWithTimeout[EnvironmentInfo](channel)
	if err != nil {
		return EnvironmentInfo{}, err
	}
	log.Println("Agent info received: ", agentInfo)
	if err := SendWithTimeout(channel, serverInfo); err != nil {
		// Without the reply the agent cannot complete its handshake, so this
		// must not be reported as success.
		return EnvironmentInfo{}, err
	}
	return agentInfo, nil
}
// Events sent over the websocket connection that is established between
// agent and converge server. This is done as soon as the agent starts.

// CheckProtocolVersion performs the first communication between agent and
// converge server. Both sides exchange their protocol version, and if they
// differ an error is returned so that the session is terminated.
//
// Message order:
//   - Agent sends its version first, then receives the server's.
//   - ConvergeServer receives first and always replies with its own version
//     before comparing, so the agent can detect a mismatch as well.
//
// The mismatch error always reports the agent's version first and the
// converge server's version second, regardless of which side detects it.
//
// Panics if role is neither Agent nor ConvergeServer.
func CheckProtocolVersion(role Role, channel GOBChannel) error {
	switch role {
	case Agent:
		err := SendWithTimeout(channel, ProtocolVersion{Version: agentProtocolVersion})
		if err != nil {
			return err
		}
		version, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		if version.Version != agentProtocolVersion {
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				agentProtocolVersion, version.Version)
		}
		return nil
	case ConvergeServer:
		version, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		err = SendWithTimeout(channel, ProtocolVersion{Version: serverProtocolVersion})
		if err != nil {
			return err
		}
		if version.Version != serverProtocolVersion {
			// version was sent by the agent, so it belongs in the agent slot.
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				version.Version, serverProtocolVersion)
		}
		return nil
	default:
		panic(fmt.Errorf("unexpected role %v", role))
	}
}
// Session info metadata exchange. These messages are sent over the SSH
// connection. The SSH server embedded in the agent listens for connections,
// but a custom listener (AgentListener) decorates the yamux Session (which is
// a listener) and uses each connection to exchange some metadata before the
// connection is handed back to SSH.

// SendClientInfo GOB-encodes info onto conn, subject to the send timeout.
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
	return SendWithTimeout(NewGOBChannel(conn), info)
}
// ReceiveClientInfo reads one GOB-encoded ClientInfo from conn, subject to the
// receive timeout. On error the zero ClientInfo is returned.
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
	info, receiveErr := ReceiveWithTimeout[ClientInfo](NewGOBChannel(conn))
	if receiveErr != nil {
		return ClientInfo{}, receiveErr
	}
	return info, nil
}
// SendRegistrationMessage GOB-encodes registration onto conn, subject to the
// send timeout. This is the message sent on the initial connection from
// server to agent to confirm the registration.
func SendRegistrationMessage(conn io.ReadWriter, registration AgentRegistration) error {
	return SendWithTimeout(NewGOBChannel(conn), registration)
}
// ReceiveRegistrationMessage reads the AgentRegistration that confirms the
// registration from conn, subject to the receive timeout. On error the zero
// AgentRegistration is returned.
func ReceiveRegistrationMessage(conn io.ReadWriter) (AgentRegistration, error) {
	reg, receiveErr := ReceiveWithTimeout[AgentRegistration](NewGOBChannel(conn))
	if receiveErr != nil {
		return AgentRegistration{}, receiveErr
	}
	return reg, nil
}