Protocol version is now checked when the agent connects to the converge server. Next up: sending connection metadata and the username/password from the server to the agent, and sending environment information back to the server. After that, the side channel will only be used for expiry time messages and the session type, with the client id passed along so that the converge server can correlate the results back to the correct channel.
177 lines
4.0 KiB
Go
177 lines
4.0 KiB
Go
package comms
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/hashicorp/yamux"
|
|
"io"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
type CommChannel struct {
|
|
// a separet connection outside of the ssh session
|
|
SideChannel GOBChannel
|
|
Session *yamux.Session
|
|
}
|
|
|
|
// Role tells which end of the agent <-> converge server connection the local
// process is; it decides how the multiplexed session and side channel are
// set up.
type Role int

// Roles with fixed numeric values (they show up in "Undefined role %d"
// diagnostics).
const (
	// Agent is the agent end of the connection.
	Agent Role = 0
	// ConvergeServer is the converge server end of the connection.
	ConvergeServer Role = 1
)
|
|
|
|
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|
var commChannel CommChannel
|
|
switch role {
|
|
case Agent:
|
|
listener, err := yamux.Server(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: listener,
|
|
}
|
|
case ConvergeServer:
|
|
clientSession, err := yamux.Client(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: clientSession,
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
|
|
// communication between Agent and ConvergeServer
|
|
// Currently used only fof communication from Agent to ConvergeServer
|
|
|
|
switch role {
|
|
case Agent:
|
|
conn, err := commChannel.Session.OpenStream()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
case ConvergeServer:
|
|
conn, err := commChannel.Session.Accept()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
log.Println("Communication channel between agent and converge server established")
|
|
|
|
// heartbeat
|
|
if role == Agent {
|
|
go func() {
|
|
for {
|
|
time.Sleep(10 * time.Second)
|
|
err := commChannel.SideChannel.Send(HeartBeat{})
|
|
if err != nil {
|
|
log.Println("Sending heartbeat to server failed")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
return commChannel, nil
|
|
}
|
|
|
|
// Sending an event to the other side
|
|
|
|
func ListenForAgentEvents(channel GOBChannel,
|
|
agentInfo func(agent AgentInfo),
|
|
sessionInfo func(session SessionInfo),
|
|
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
// TODO more clean solution, need to explicitly close when agent exits.
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
switch v := result.Value.(type) {
|
|
|
|
case AgentInfo:
|
|
agentInfo(v)
|
|
|
|
case SessionInfo:
|
|
sessionInfo(v)
|
|
|
|
case ExpiryTimeUpdate:
|
|
expiryTimeUpdate(v)
|
|
|
|
case HeartBeat:
|
|
// for not ignoring, can also implement behavior
|
|
// when heartbeat not received but hearbeat is only
|
|
// intended to keep the connection up
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %T\n", v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ListenForServerEvents(channel CommChannel,
|
|
setUsernamePassword func(user UserPassword)) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.SideChannel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
// TODO more clean solution, need to explicitly close when agent exits.
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
switch v := result.Value.(type) {
|
|
|
|
case UserPassword:
|
|
setUsernamePassword(v)
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %T\n", v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
|
|
channel := NewGOBChannel(conn)
|
|
|
|
sends := make(chan any)
|
|
receives := make(chan any)
|
|
errors := make(chan error)
|
|
|
|
channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
|
channel.ReceiveAsync(receives, errors)
|
|
|
|
select {
|
|
case <-time.After(10 * time.Second):
|
|
log.Println("PROTOCOLVERSION: timeout")
|
|
return fmt.Errorf("Timeout waiting for protocol version")
|
|
case err := <-errors:
|
|
log.Printf("PROTOCOLVERSION: %v", err)
|
|
return err
|
|
case protocolVersion := <-receives:
|
|
otherVersion := protocolVersion.(ProtocolVersion).Version
|
|
if PROTOCOL_VERSION != otherVersion {
|
|
switch role {
|
|
case Agent:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, otherVersion)
|
|
case ConvergeServer:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
otherVersion, PROTOCOL_VERSION)
|
|
}
|
|
return fmt.Errorf("Protocol version mismatch")
|
|
}
|
|
log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version)
|
|
return nil
|
|
}
|
|
}
|