converge/pkg/comms/agentserver.go
Erik Brakkee ada34495ef GOB channel for easily and asynchronously using GOB on a single network connection, also dealing with timeouts and errors in a good way.
Protocol version is now checked when the agent connects to the converge server.

Next up: sending connection metadata and username/password from the server to the agent, and sending environment information back to the server. This means that the side channel will only be used for expiry time messages and session type, with the client id passed in so the converge server can then correlate the results back to the correct channel.
2024-09-08 11:16:49 +02:00

177 lines
4.0 KiB
Go

package comms
import (
"fmt"
"github.com/hashicorp/yamux"
"io"
"log"
"time"
)
// CommChannel bundles the yamux session that multiplexes the connection
// between an agent and the converge server with the GOB-encoded side channel
// that runs on one of its streams.
type CommChannel struct {
	// a separate yamux stream, outside of the ssh session, used to exchange
	// GOB-encoded messages between agent and converge server
	SideChannel GOBChannel
	// the yamux session multiplexing the underlying connection
	Session *yamux.Session
}
// Role tells NewCommChannel and CheckProtocolVersion which end of the
// agent/converge server connection the local process represents.
type Role int

// The two possible roles. The numeric values are fixed explicitly.
const (
	// Agent is the end that opens the side channel stream.
	Agent Role = 0
	// ConvergeServer is the end that accepts the side channel stream.
	ConvergeServer Role = 1
)
// NewCommChannel multiplexes wsConn with yamux and sets up the GOB side
// channel used for messages between the agent and the converge server.
//
// The agent acts as the yamux server and opens the side channel stream; the
// converge server acts as the yamux client and accepts that stream. For the
// Agent role a background goroutine additionally sends a HeartBeat over the
// side channel every 10 seconds until the yamux session is closed.
//
// If setting up the side channel fails, the yamux session is closed (which
// also closes wsConn) so that its internal goroutines do not leak. An
// unknown role is reported as an error.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var commChannel CommChannel
	switch role {
	case Agent:
		session, err := yamux.Server(wsConn, nil)
		if err != nil {
			return CommChannel{}, fmt.Errorf("creating yamux server session: %w", err)
		}
		// The agent initiates the side channel stream.
		stream, err := session.OpenStream()
		if err != nil {
			_ = session.Close() // best effort; the open error is what matters
			return CommChannel{}, fmt.Errorf("opening side channel stream: %w", err)
		}
		commChannel = CommChannel{
			SideChannel: NewGOBChannel(stream),
			Session:     session,
		}
	case ConvergeServer:
		session, err := yamux.Client(wsConn, nil)
		if err != nil {
			return CommChannel{}, fmt.Errorf("creating yamux client session: %w", err)
		}
		// The converge server waits for the agent's side channel stream.
		conn, err := session.Accept()
		if err != nil {
			_ = session.Close() // best effort; the accept error is what matters
			return CommChannel{}, fmt.Errorf("accepting side channel stream: %w", err)
		}
		commChannel = CommChannel{
			SideChannel: NewGOBChannel(conn),
			Session:     session,
		}
	default:
		return CommChannel{}, fmt.Errorf("undefined role %d", role)
	}
	log.Println("Communication channel between agent and converge server established")

	// heartbeat: only the agent sends them, to keep the connection in use.
	if role == Agent {
		go func() {
			ticker := time.NewTicker(10 * time.Second)
			defer ticker.Stop()
			for {
				select {
				case <-commChannel.Session.CloseChan():
					// The session is gone; there is nothing left to keep alive.
					return
				case <-ticker.C:
					if err := commChannel.SideChannel.Send(HeartBeat{}); err != nil {
						log.Printf("Sending heartbeat to server failed: %v", err)
					}
				}
			}
		}()
	}
	return commChannel, nil
}
// ListenForAgentEvents receives ConvergeMessage values from channel in a loop
// and dispatches each payload to the matching callback: AgentInfo to
// agentInfo, SessionInfo to sessionInfo and ExpiryTimeUpdate to
// expiryTimeUpdate. HeartBeat messages are ignored and payloads of any other
// type are logged and skipped.
//
// The function blocks. It returns only when decoding fails, presumably
// because the connection to the agent was closed.
func ListenForAgentEvents(channel GOBChannel,
	agentInfo func(agent AgentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate)) {
	for {
		var result ConvergeMessage
		if err := channel.Decoder.Decode(&result); err != nil {
			// TODO more clean solution, need to explicitly close when agent exits.
			log.Printf("Exiting agent listener %v", err)
			return
		}
		switch v := result.Value.(type) {
		case AgentInfo:
			agentInfo(v)
		case SessionInfo:
			sessionInfo(v)
		case ExpiryTimeUpdate:
			expiryTimeUpdate(v)
		case HeartBeat:
			// Ignored for now. Behavior for missing heartbeats could be added
			// here, but the heartbeat is only intended to keep the connection up.
		default:
			// Logged via the log package like every other diagnostic here.
			log.Printf("Unknown type: %T", v)
		}
	}
}
// ListenForServerEvents receives ConvergeMessage values from the side channel
// of channel in a loop and passes UserPassword payloads to
// setUsernamePassword. Payloads of any other type are logged and skipped.
//
// The function blocks. It returns only when decoding fails, presumably
// because the connection to the converge server was closed.
func ListenForServerEvents(channel CommChannel,
	setUsernamePassword func(user UserPassword)) {
	for {
		var result ConvergeMessage
		if err := channel.SideChannel.Decoder.Decode(&result); err != nil {
			// TODO more clean solution, need to explicitly close when the
			// connection to the converge server goes away.
			log.Printf("Exiting server listener %v", err)
			return
		}
		switch v := result.Value.(type) {
		case UserPassword:
			setUsernamePassword(v)
		default:
			// Logged via the log package like every other diagnostic here.
			log.Printf("Unknown type: %T", v)
		}
	}
}
// CheckProtocolVersion exchanges ProtocolVersion messages over conn and
// verifies that the peer speaks the same PROTOCOL_VERSION as this side.
//
// The local version is sent and the peer's version is received concurrently.
// An error is returned when no message arrives within 10 seconds, when
// sending or receiving fails, when the first received message is not a
// ProtocolVersion, or when the versions differ. role only determines how a
// mismatch is logged (which version belongs to the agent and which to the
// converge server).
func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
	channel := NewGOBChannel(conn)

	// Buffered so the asynchronous send/receive can always deliver its result
	// and finish, even after this function has already returned (e.g. nobody
	// ever reads sends, and after a timeout nobody reads receives).
	// NOTE(review): assumes SendAsync/ReceiveAsync each deliver a single
	// result on either their value channel or errs — confirm in GOBChannel.
	sends := make(chan any, 1)
	receives := make(chan any, 1)
	errs := make(chan error, 2)
	channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errs)
	channel.ReceiveAsync(receives, errs)

	select {
	case <-time.After(10 * time.Second):
		log.Println("PROTOCOLVERSION: timeout")
		return fmt.Errorf("timeout waiting for protocol version")
	case err := <-errs:
		log.Printf("PROTOCOLVERSION: %v", err)
		return err
	case received := <-receives:
		// A peer speaking another protocol may send something else first;
		// report that instead of panicking on the type assertion.
		peer, ok := received.(ProtocolVersion)
		if !ok {
			log.Printf("PROTOCOLVERSION: unexpected message of type %T", received)
			return fmt.Errorf("expected protocol version message, got %T", received)
		}
		otherVersion := peer.Version
		if PROTOCOL_VERSION != otherVersion {
			switch role {
			case Agent:
				log.Printf("Protocol version mismatch: agent %d, converge server %d",
					PROTOCOL_VERSION, otherVersion)
			case ConvergeServer:
				log.Printf("Protocol version mismatch: agent %d, converge server %d",
					otherVersion, PROTOCOL_VERSION)
			}
			return fmt.Errorf("protocol version mismatch: local %d, peer %d",
				PROTOCOL_VERSION, otherVersion)
		}
		log.Printf("PROTOCOLVERSION: %v", otherVersion)
		return nil
	}
}