the agent if PPROF_PORT is set. Fixed an issue where the converge server did not clean up goroutines because of a blocking channel. Channels are now created with a buffer (size > 1) wherever possible, since the blocking behavior of an unbuffered channel mostly gets in the way. Known issue: killing the SSH client leaves the server-side process running and some goroutines still active in the agent; solving this requires further investigation. The remote processes are still cleaned up correctly (at least on Linux) when the agent exits. This should not be a real problem, because the agent is a short-lived process, and in a containerized environment where containers run on demand the cleanup will definitely work.
274 lines
6.7 KiB
Go
274 lines
6.7 KiB
Go
package comms
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/hashicorp/yamux"
|
|
"io"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
// MESSAGE_TIMEOUT is the time allowed for a single handshake message.
// CheckProtocolVersionOld uses it directly as its wait limit; presumably
// SendWithTimeout/ReceiveWithTimeout (defined elsewhere) use it too — verify.
const MESSAGE_TIMEOUT time.Duration = 10 * time.Second
|
|
|
|
// CommChannel bundles the yamux session that NewCommChannel creates on top of
// the websocket connection with the side channel opened on that session.
type CommChannel struct {
	// SideChannel is a separate connection outside of the SSH session, used for
	// messages exchanged directly between agent and converge server.
	SideChannel GOBChannel
	// Session is the yamux session multiplexed over the websocket connection.
	Session *yamux.Session
}
|
|
|
|
// Role tells the communication helpers which end of the agent <-> converge
// server connection the local process is. It selects the yamux mode and the
// order in which handshake messages are sent and received.
type Role int

// The two possible roles. The values are spelled out explicitly; they match
// the iota-based numbering (Agent = 0, ConvergeServer = 1).
const (
	// Agent is the side that runs the yamux server and opens the side channel.
	Agent Role = 0
	// ConvergeServer is the side that runs the yamux client and accepts the
	// side channel.
	ConvergeServer Role = 1
)
|
|
|
|
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|
var commChannel CommChannel
|
|
|
|
switch role {
|
|
case Agent:
|
|
listener, err := yamux.Server(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: listener,
|
|
}
|
|
case ConvergeServer:
|
|
clientSession, err := yamux.Client(wsConn, nil)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
commChannel = CommChannel{
|
|
Session: clientSession,
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
|
|
// communication between Agent and ConvergeServer
|
|
// Currently used only fof communication from Agent to ConvergeServer
|
|
|
|
switch role {
|
|
case Agent:
|
|
conn, err := commChannel.Session.OpenStream()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
case ConvergeServer:
|
|
conn, err := commChannel.Session.Accept()
|
|
commChannel.SideChannel = NewGOBChannel(conn)
|
|
if err != nil {
|
|
return CommChannel{}, err
|
|
}
|
|
default:
|
|
panic(fmt.Errorf("Undefined role %d", role))
|
|
}
|
|
log.Println("Communication channel between agent and converge server established")
|
|
|
|
// heartbeat
|
|
if role == Agent {
|
|
go func() {
|
|
for {
|
|
time.Sleep(10 * time.Second)
|
|
err := Send(commChannel.SideChannel,
|
|
ConvergeMessage{
|
|
Value: HeartBeat{},
|
|
})
|
|
if err != nil {
|
|
log.Println("Sending heartbeat to server failed")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
return commChannel, nil
|
|
}
|
|
|
|
// Sending an event to the other side
|
|
|
|
func ListenForAgentEvents(channel GOBChannel,
|
|
agentInfo func(agent AgentInfo),
|
|
sessionInfo func(session SessionInfo),
|
|
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
switch v := result.Value.(type) {
|
|
|
|
case AgentInfo:
|
|
agentInfo(v)
|
|
|
|
case SessionInfo:
|
|
sessionInfo(v)
|
|
|
|
case ExpiryTimeUpdate:
|
|
expiryTimeUpdate(v)
|
|
|
|
case HeartBeat:
|
|
// for not ignoring, can also implement behavior
|
|
// when heartbeat not received but hearbeat is only
|
|
// intended to keep the connection up
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %v %T\n", v, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ListenForServerEvents(channel CommChannel) {
|
|
for {
|
|
var result ConvergeMessage
|
|
err := channel.SideChannel.Decoder.Decode(&result)
|
|
|
|
if err != nil {
|
|
log.Printf("Exiting agent listener %v", err)
|
|
return
|
|
}
|
|
|
|
// no supported server events at this time.
|
|
switch v := result.Value.(type) {
|
|
|
|
default:
|
|
fmt.Printf(" Unknown type: %T\n", v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) {
|
|
channel := NewGOBChannel(conn)
|
|
err := CheckProtocolVersion(Agent, channel)
|
|
|
|
err = SendWithTimeout(channel, agentInto)
|
|
if err != nil {
|
|
return ServerInfo{}, nil
|
|
}
|
|
serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
|
|
if err != nil {
|
|
return ServerInfo{}, nil
|
|
}
|
|
// TODO remove logging
|
|
log.Println("Agent configuration received from server")
|
|
|
|
return serverInfo, err
|
|
}
|
|
|
|
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) {
|
|
channel := NewGOBChannel(conn)
|
|
err := CheckProtocolVersion(ConvergeServer, channel)
|
|
|
|
agentInfo, err := ReceiveWithTimeout[AgentInfo](channel)
|
|
if err != nil {
|
|
return AgentInfo{}, err
|
|
}
|
|
log.Println("Agent info received: ", agentInfo)
|
|
err = SendWithTimeout(channel, serverInfo)
|
|
if err != nil {
|
|
return AgentInfo{}, nil
|
|
}
|
|
return agentInfo, err
|
|
}
|
|
|
|
// Events sent over the websocket connection that is established between
|
|
// agent and converge server. This is done as soon as the agent starts.
|
|
|
|
// First commmunication between agent and Converge Server
|
|
// Both exchange their protocol version and if it is incorrect, the session
|
|
// is terminated.
|
|
|
|
func CheckProtocolVersion(role Role, channel GOBChannel) error {
|
|
switch role {
|
|
case Agent:
|
|
err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if version.Version != PROTOCOL_VERSION {
|
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, version.Version)
|
|
}
|
|
return nil
|
|
case ConvergeServer:
|
|
version, err := ReceiveWithTimeout[ProtocolVersion](channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if version.Version != PROTOCOL_VERSION {
|
|
return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, version.Version)
|
|
}
|
|
return nil
|
|
default:
|
|
panic(fmt.Errorf("unexpected rolg %v", role))
|
|
}
|
|
}
|
|
|
|
func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error {
|
|
channel := NewGOBChannel(conn)
|
|
|
|
sends := make(chan bool, 10)
|
|
receives := make(chan ProtocolVersion, 10)
|
|
errors := make(chan error, 10)
|
|
|
|
SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
|
ReceiveAsync(channel, receives, errors)
|
|
|
|
select {
|
|
case <-time.After(MESSAGE_TIMEOUT):
|
|
log.Println("PROTOCOLVERSION: timeout")
|
|
return fmt.Errorf("Timeout waiting for protocol version")
|
|
case err := <-errors:
|
|
log.Printf("PROTOCOLVERSION: %v", err)
|
|
return err
|
|
case protocolVersion := <-receives:
|
|
otherVersion := protocolVersion.Version
|
|
if PROTOCOL_VERSION != otherVersion {
|
|
switch role {
|
|
case Agent:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
PROTOCOL_VERSION, otherVersion)
|
|
case ConvergeServer:
|
|
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
|
otherVersion, PROTOCOL_VERSION)
|
|
}
|
|
return fmt.Errorf("Protocol version mismatch")
|
|
}
|
|
log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Session info metadata exchange. These messages are sent over the SSH connection. The agent's
// embedded SSH server listens for connections, but we provide a custom listener (AgentListener)
// that decorates the yamux Session (which is a listener) and uses this connection to exchange
// some metadata before the connection is handed back to SSH.
|
|
|
|
func SendClientInfo(conn io.ReadWriter, info ClientInfo) error {
|
|
channel := NewGOBChannel(conn)
|
|
return SendWithTimeout(channel, info)
|
|
}
|
|
|
|
func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) {
|
|
channel := NewGOBChannel(conn)
|
|
clientInfo, err := ReceiveWithTimeout[ClientInfo](channel)
|
|
if err != nil {
|
|
return ClientInfo{}, err
|
|
}
|
|
return clientInfo, nil
|
|
}
|