converge/pkg/comms/agentserver.go
Erik Brakkee 28b2545163 test for connecting clients and bidirectional communication to agent.
Required lots of rework since the GOBChannel appeared to be reading
ahead of the data it actually needed. Now using more low-level IO
to send the clientId over to the agent instead.
2024-09-08 11:16:49 +02:00

272 lines
6.8 KiB
Go

package comms
import (
"encoding/binary"
"fmt"
"github.com/hashicorp/yamux"
"io"
"log"
"time"
)
// MESSAGE_TIMEOUT is the time budget for a single message exchange.
// NOTE(review): presumably consumed by SendWithTimeout/ReceiveWithTimeout,
// which are defined elsewhere in the package — confirm.
const MESSAGE_TIMEOUT time.Duration = 10 * time.Second
// CommChannel bundles the yamux session that multiplexes the connection
// between agent and converge server (presumably the websocket passed to
// NewCommChannel) together with the side channel opened on that session.
type CommChannel struct {
	// a separate connection outside of the ssh session; it is a GOB-encoded
	// wrapper around one yamux stream of Session (see NewCommChannel)
	SideChannel GOBChannel
	// Session is the yamux multiplexer over the underlying connection.
	Session *yamux.Session
}
// Role tells the communication helpers which end of the agent <-> converge
// server connection the local process is. It decides, for example, which
// side acts as yamux server and who speaks first in the version handshake.
type Role int

const (
	// Agent is the agent end of the connection.
	Agent Role = 0
	// ConvergeServer is the Converge Server end of the connection.
	ConvergeServer Role = 1
)
// NewCommChannel multiplexes wsConn with yamux and establishes the side
// channel used for messages between agent and converge server.
//
// The Agent runs the yamux server side and opens the side-channel stream;
// the ConvergeServer runs the yamux client side and accepts that stream.
// Any other role is a programming error and panics.
//
// If the side channel cannot be established, the yamux session is closed
// (which also closes wsConn) so that its background goroutines do not leak.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var (
		session *yamux.Session
		err     error
	)
	switch role {
	case Agent:
		session, err = yamux.Server(wsConn, nil)
	case ConvergeServer:
		session, err = yamux.Client(wsConn, nil)
	default:
		panic(fmt.Errorf("Undefined role %d", role))
	}
	if err != nil {
		return CommChannel{}, err
	}

	// Communication between Agent and ConvergeServer. It is currently used
	// only for communication from Agent to ConvergeServer.
	var sideConn io.ReadWriteCloser
	if role == Agent {
		sideConn, err = session.OpenStream()
	} else {
		sideConn, err = session.Accept()
	}
	// Check the error before using sideConn: it is not valid on failure.
	if err != nil {
		_ = session.Close()
		return CommChannel{}, err
	}

	log.Println("Communication channel between agent and converge server established")
	return CommChannel{
		SideChannel: NewGOBChannel(sideConn),
		Session:     session,
	}, nil
}
// SetupHeartBeat starts a background goroutine that sends a HeartBeat message
// from the agent to the server over the side channel every 10 seconds. The
// heartbeat is only intended to keep the connection up.
//
// A failed send is logged and retried on the next tick. The goroutine exits
// once the yamux session has been closed, so it does not outlive the
// connection. If commChannel.Session is nil it runs for the process lifetime.
func SetupHeartBeat(commChannel CommChannel) {
	go func() {
		ticker := time.NewTicker(10 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			if commChannel.Session != nil && commChannel.Session.IsClosed() {
				return
			}
			err := Send(commChannel.SideChannel,
				ConvergeMessage{
					Value: HeartBeat{},
				})
			if err != nil {
				log.Printf("Sending heartbeat to server failed: %v", err)
			}
		}
	}()
}
// ListenForAgentEvents decodes ConvergeMessage values from channel in a loop
// and passes each payload to the callback registered for its concrete type.
// It logs and returns on the first decode error (e.g. when the underlying
// connection is closed). Payloads of any other type are printed and skipped.
// Callbacks run on the calling goroutine.
func ListenForAgentEvents(channel GOBChannel,
	agentInfo func(agent EnvironmentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate),
	heartBeat func(heartbeat HeartBeat)) {
	receive := func() (ConvergeMessage, error) {
		var msg ConvergeMessage
		err := channel.Decoder.Decode(&msg)
		return msg, err
	}
	for {
		msg, err := receive()
		if err != nil {
			log.Printf("Exiting agent listener %v", err)
			return
		}
		switch event := msg.Value.(type) {
		case EnvironmentInfo:
			agentInfo(event)
		case SessionInfo:
			sessionInfo(event)
		case ExpiryTimeUpdate:
			expiryTimeUpdate(event)
		case HeartBeat:
			// Heartbeats exist only to keep the connection up; they are still
			// forwarded so the caller can act on them (or on their absence).
			heartBeat(event)
		default:
			fmt.Printf(" Unknown type: %v %T\n", event, event)
		}
	}
}
// ListenForServerEvents drains messages arriving on the side channel of
// channel until decoding fails, at which point it logs the error and returns.
// No server events are defined yet, so every message is reported by type
// only.
func ListenForServerEvents(channel CommChannel) {
	for {
		var msg ConvergeMessage
		if err := channel.SideChannel.Decoder.Decode(&msg); err != nil {
			log.Printf("Exiting agent listener %v", err)
			return
		}
		fmt.Printf(" Unknown type: %T\n", msg.Value)
	}
}
// AgentInitialization performs the agent side of the initial handshake on
// conn:
//  1. check the protocol version,
//  2. send agentInfo,
//  3. receive the ServerInfo from the converge server.
//
// Any failure is returned to the caller together with a zero ServerInfo.
// Previously, send/receive failures were reported as success.
func AgentInitialization(conn io.ReadWriter, agentInfo EnvironmentInfo) (ServerInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(Agent, channel); err != nil {
		return ServerInfo{}, err
	}
	if err := SendWithTimeout(channel, agentInfo); err != nil {
		return ServerInfo{}, err
	}
	serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
	if err != nil {
		return ServerInfo{}, err
	}
	log.Println("Agent configuration received from server")
	return serverInfo, nil
}
// ServerInitialization performs the converge-server side of the initial
// handshake on conn:
//  1. check the protocol version,
//  2. receive the agent's EnvironmentInfo,
//  3. reply with serverInfo.
//
// Any failure, including a failed reply, is returned to the caller together
// with a zero EnvironmentInfo.
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (EnvironmentInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(ConvergeServer, channel); err != nil {
		return EnvironmentInfo{}, err
	}
	agentInfo, err := ReceiveWithTimeout[EnvironmentInfo](channel)
	if err != nil {
		return EnvironmentInfo{}, err
	}
	log.Println("Agent info received: ", agentInfo)
	if err := SendWithTimeout(channel, serverInfo); err != nil {
		return EnvironmentInfo{}, err
	}
	return agentInfo, nil
}
// CheckProtocolVersion is the first communication between agent and Converge
// Server. It happens over the websocket connection established between them
// as soon as the agent starts. Both sides exchange their protocol version; if
// the versions differ an error is returned and the session is terminated.
//
// The Agent sends its version first and then waits for the server's. The
// ConvergeServer receives first and then replies. Any other role is a
// programming error and panics.
func CheckProtocolVersion(role Role, channel GOBChannel) error {
	switch role {
	case Agent:
		err := SendWithTimeout(channel, ProtocolVersion{Version: agentProtocolVersion})
		if err != nil {
			return err
		}
		version, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		if version.Version != agentProtocolVersion {
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				agentProtocolVersion, version.Version)
		}
		return nil
	case ConvergeServer:
		version, err := ReceiveWithTimeout[ProtocolVersion](channel)
		if err != nil {
			return err
		}
		// Reply before comparing so the agent can detect a mismatch too.
		err = SendWithTimeout(channel, ProtocolVersion{Version: serverProtocolVersion})
		if err != nil {
			return err
		}
		if version.Version != serverProtocolVersion {
			// version was sent by the agent, so it belongs in the "agent" slot.
			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
				version.Version, serverProtocolVersion)
		}
		return nil
	default:
		panic(fmt.Errorf("unexpected role %v", role))
	}
}
// Session info metadata exchange. These are sent over the SSH connection. The
// agent's embedded SSH server listens for connections, but we provide a custom
// listener (AgentListener) that decorates the yamux Session (which is a
// listener). It uses this connection to exchange some metadata before the
// connection is handed back to SSH.
//
// GOB cannot be used for sending the client info: GOB does buffered reads,
// which would mix with the other reads on the same connection. Alternatively,
// we could wrap the GOB message, encode its length, and read exactly that many
// bytes when decoding. But since the client info is just a string, a simple
// length prefix is easier.
//
// SendClientInfo writes info as a single frame: a 4-byte big-endian length
// followed by the raw bytes of info. It returns an error without writing
// anything if info is too long for a 32-bit length prefix.
func SendClientInfo(conn io.Writer, info string) error {
	// uint32(len(info)) would silently truncate and corrupt the frame.
	if uint64(len(info)) > 1<<32-1 {
		return fmt.Errorf("client info too long: %d bytes", len(info))
	}
	frame := make([]byte, 4+len(info))
	binary.BigEndian.PutUint32(frame, uint32(len(info)))
	copy(frame[4:], info)
	// One Write keeps the header and payload together on the stream.
	_, err := conn.Write(frame)
	return err
}
// ReceiveClientInfo reads one frame written by SendClientInfo: a 4-byte
// big-endian length followed by that many bytes, returned as a string. If
// either part cannot be read completely, the io.ReadFull error is returned
// (io.EOF or io.ErrUnexpectedEOF on a short stream).
// NOTE(review): the payload buffer is sized from the peer-supplied length
// (up to 4 GiB) before any payload bytes arrive; consider a cap if the peer
// is not fully trusted.
func ReceiveClientInfo(conn io.Reader) (string, error) {
	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return "", err
	}
	payload := make([]byte, binary.BigEndian.Uint32(header[:]))
	if _, err := io.ReadFull(conn, payload); err != nil {
		return "", err
	}
	return string(payload), nil
}
// SendRegistrationMessage sends registration as a GOB message on the initial
// connection from server to agent, confirming the agent's registration. It
// returns the error from SendWithTimeout.
func SendRegistrationMessage(conn io.ReadWriter, registration AgentRegistration) error {
	return SendWithTimeout(NewGOBChannel(conn), registration)
}
// ReceiveRegistrationMessage waits on conn for the AgentRegistration sent by
// SendRegistrationMessage. On failure the zero AgentRegistration is returned
// together with the error.
func ReceiveRegistrationMessage(conn io.ReadWriter) (AgentRegistration, error) {
	registration, err := ReceiveWithTimeout[AgentRegistration](NewGOBChannel(conn))
	if err == nil {
		return registration, nil
	}
	return AgentRegistration{}, err
}