package comms

import (
	"fmt"
	"io"
	"log"
	"time"

	"github.com/hashicorp/yamux"
)

// MESSAGE_TIMEOUT is how long to wait for a message from the other side
// before giving up.
const MESSAGE_TIMEOUT = 10 * time.Second

// CommChannel is the multiplexed connection between an agent and the
// converge server.
type CommChannel struct {
	// SideChannel is a separate connection, outside of the SSH session, used
	// for agent <-> converge server messages.
	SideChannel GOBChannel
	// Session multiplexes streams over the connection passed to
	// NewCommChannel.
	Session *yamux.Session
}

// Role identifies which end of the connection the current process is.
type Role int

const (
	// Agent is the agent end; it acts as the yamux server.
	Agent Role = iota
	// ConvergeServer is the converge server end; it acts as the yamux client.
	ConvergeServer
)

// NewCommChannel layers a yamux session over wsConn and sets up the side
// channel used for messages between the agent and the converge server.
//
// The agent side runs yamux in server mode and opens the side-channel stream;
// the converge server side runs yamux in client mode and accepts that stream.
// On the agent side, a HeartBeat message is sent over the side channel every
// 10 seconds until the yamux session is closed.
//
// If the side channel cannot be established, the yamux session (and with it
// wsConn) is closed and the error is returned. It panics if role is neither
// Agent nor ConvergeServer.
func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
	var (
		session *yamux.Session
		err     error
	)
	switch role {
	case Agent:
		session, err = yamux.Server(wsConn, nil)
	case ConvergeServer:
		session, err = yamux.Client(wsConn, nil)
	default:
		panic(fmt.Errorf("Undefined role %d", role))
	}
	if err != nil {
		return CommChannel{}, err
	}
	commChannel := CommChannel{Session: session}

	// Communication between Agent and ConvergeServer.
	// Currently used only for communication from Agent to ConvergeServer.
	// role has already been validated above, so no default case is needed.
	switch role {
	case Agent:
		stream, err := session.OpenStream()
		if err != nil {
			// Don't leave the session's background goroutines running.
			_ = session.Close()
			return CommChannel{}, err
		}
		commChannel.SideChannel = NewGOBChannel(stream)
	case ConvergeServer:
		stream, err := session.Accept()
		if err != nil {
			_ = session.Close()
			return CommChannel{}, err
		}
		commChannel.SideChannel = NewGOBChannel(stream)
	}

	log.Println("Communication channel between agent and converge server established")

	// Heartbeat: only intended to keep the connection up. The goroutine exits
	// once the session is closed instead of looping forever.
	if role == Agent {
		const heartbeatInterval = 10 * time.Second
		go func() {
			ticker := time.NewTicker(heartbeatInterval)
			defer ticker.Stop()
			for {
				select {
				case <-commChannel.Session.CloseChan():
					return
				case <-ticker.C:
					err := Send(commChannel.SideChannel, ConvergeMessage{
						Value: HeartBeat{},
					})
					if err != nil {
						log.Printf("Sending heartbeat to server failed: %v", err)
					}
				}
			}
		}()
	}

	return commChannel, nil
}

// ListenForAgentEvents receives messages from the agent over channel and
// dispatches each one to the matching callback: AgentInfo to agentInfo,
// SessionInfo to sessionInfo and ExpiryTimeUpdate to expiryTimeUpdate.
// HeartBeat messages are ignored and unknown message types are logged.
// It blocks until decoding a message fails, then returns.
func ListenForAgentEvents(channel GOBChannel,
	agentInfo func(agent AgentInfo),
	sessionInfo func(session SessionInfo),
	expiryTimeUpdate func(session ExpiryTimeUpdate)) {
	for {
		var result ConvergeMessage
		err := channel.Decoder.Decode(&result)
		if err != nil {
			log.Printf("Exiting agent listener %v", err)
			return
		}
		switch v := result.Value.(type) {
		case AgentInfo:
			agentInfo(v)
		case SessionInfo:
			sessionInfo(v)
		case ExpiryTimeUpdate:
			expiryTimeUpdate(v)
		case HeartBeat:
			// Deliberately ignored. Behavior for missing heartbeats could be
			// implemented here, but the heartbeat is only intended to keep
			// the connection up.
		default:
			log.Printf("Unknown type: %v %T", v, v)
		}
	}
}

// ListenForServerEvents receives messages from the converge server over the
// side channel of channel. No server events are supported yet, so every
// message is logged as unknown. It blocks until decoding a message fails,
// then returns.
func ListenForServerEvents(channel CommChannel) {
	for {
		var result ConvergeMessage
		err := channel.SideChannel.Decoder.Decode(&result)
		if err != nil {
			log.Printf("Exiting agent listener %v", err)
			return
		}
		// No supported server events at this time.
		switch v := result.Value.(type) {
		default:
			log.Printf("Unknown type: %T", v)
		}
	}
}

// AgentInitialization performs the agent side of the initial exchange over
// conn: it verifies the protocol version, sends agentInfo and waits for the
// server's ServerInfo reply. Any failure, including a protocol version
// mismatch, is returned as an error together with a zero ServerInfo.
func AgentInitialization(conn io.ReadWriter, agentInfo AgentInfo) (ServerInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(Agent, channel); err != nil {
		return ServerInfo{}, err
	}
	if err := SendWithTimeout(channel, agentInfo); err != nil {
		return ServerInfo{}, fmt.Errorf("sending agent info: %w", err)
	}
	serverInfo, err := ReceiveWithTimeout[ServerInfo](channel)
	if err != nil {
		return ServerInfo{}, fmt.Errorf("receiving server info: %w", err)
	}
	// TODO remove logging
	log.Println("Agent configuration received from server")
	return serverInfo, nil
}

// ServerInitialization performs the converge server side of the initial
// exchange over conn: it verifies the protocol version, waits for the
// agent's AgentInfo and replies with serverInfo. Any failure, including a
// protocol version mismatch, is returned as an error together with a zero
// AgentInfo.
func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) {
	channel := NewGOBChannel(conn)
	if err := CheckProtocolVersion(ConvergeServer, channel); err != nil {
		return AgentInfo{}, err
	}
	agentInfo, err := ReceiveWithTimeout[AgentInfo](channel)
	if err != nil {
		return AgentInfo{}, fmt.Errorf("receiving agent info: %w", err)
	}
	log.Println("Agent info received: ", agentInfo)
	if err := SendWithTimeout(channel, serverInfo); err != nil {
		return AgentInfo{}, fmt.Errorf("sending server info: %w", err)
	}
	return agentInfo, nil
}

// Events sent over the websocket connection that is established between the
// agent and the converge server. This is done as soon as the agent starts.
//
// First communication between the agent and the converge server: both
// exchange their protocol version, and if they do not match, the session
// is terminated.
func CheckProtocolVersion(role Role, channel GOBChannel) error { switch role { case Agent: err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) if err != nil { return err } version, err := ReceiveWithTimeout[ProtocolVersion](channel) if err != nil { return err } if version.Version != PROTOCOL_VERSION { return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", PROTOCOL_VERSION, version.Version) } return nil case ConvergeServer: version, err := ReceiveWithTimeout[ProtocolVersion](channel) if err != nil { return err } err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) if err != nil { return err } if version.Version != PROTOCOL_VERSION { return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", PROTOCOL_VERSION, version.Version) } return nil default: panic(fmt.Errorf("unexpected rolg %v", role)) } } func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error { channel := NewGOBChannel(conn) sends := make(chan bool, 10) receives := make(chan ProtocolVersion, 10) errors := make(chan error, 10) SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) ReceiveAsync(channel, receives, errors) select { case <-time.After(MESSAGE_TIMEOUT): log.Println("PROTOCOLVERSION: timeout") return fmt.Errorf("Timeout waiting for protocol version") case err := <-errors: log.Printf("PROTOCOLVERSION: %v", err) return err case protocolVersion := <-receives: otherVersion := protocolVersion.Version if PROTOCOL_VERSION != otherVersion { switch role { case Agent: log.Printf("Protocol version mismatch: agent %d, converge server %d", PROTOCOL_VERSION, otherVersion) case ConvergeServer: log.Printf("Protocol version mismatch: agent %d, converge server %d", otherVersion, PROTOCOL_VERSION) } return fmt.Errorf("Protocol version mismatch") } log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version) return nil } } // Session info metadata exchange. These are sent over the SSH connection. 
The agent embedded // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that // decorates the yamux Session (which is a listener) and uses this connection to exchange some // metadata before the connection is handed back to SSH. func SendClientInfo(conn io.ReadWriter, info ClientInfo) error { channel := NewGOBChannel(conn) return SendWithTimeout(channel, info) } func ReceiveClientInfo(conn io.ReadWriter) (ClientInfo, error) { channel := NewGOBChannel(conn) clientInfo, err := ReceiveWithTimeout[ClientInfo](channel) if err != nil { return ClientInfo{}, err } return clientInfo, nil }