package comms import ( "fmt" "github.com/hashicorp/yamux" "io" "log" "time" ) type CommChannel struct { // a separet connection outside of the ssh session SideChannel GOBChannel Session *yamux.Session } type Role int const ( Agent Role = iota ConvergeServer ) func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { var commChannel CommChannel switch role { case Agent: listener, err := yamux.Server(wsConn, nil) if err != nil { return CommChannel{}, err } commChannel = CommChannel{ Session: listener, } case ConvergeServer: clientSession, err := yamux.Client(wsConn, nil) if err != nil { return CommChannel{}, err } commChannel = CommChannel{ Session: clientSession, } default: panic(fmt.Errorf("Undefined role %d", role)) } // communication between Agent and ConvergeServer // Currently used only fof communication from Agent to ConvergeServer switch role { case Agent: conn, err := commChannel.Session.OpenStream() commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } case ConvergeServer: conn, err := commChannel.Session.Accept() commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } default: panic(fmt.Errorf("Undefined role %d", role)) } log.Println("Communication channel between agent and converge server established") // heartbeat if role == Agent { go func() { for { time.Sleep(10 * time.Second) err := commChannel.SideChannel.Send(HeartBeat{}) if err != nil { log.Println("Sending heartbeat to server failed") } } }() } return commChannel, nil } // Sending an event to the other side func ListenForAgentEvents(channel GOBChannel, agentInfo func(agent AgentInfo), sessionInfo func(session SessionInfo), expiryTimeUpdate func(session ExpiryTimeUpdate)) { for { var result ConvergeMessage err := channel.Decoder.Decode(&result) if err != nil { // TODO more clean solution, need to explicitly close when agent exits. log.Printf("Exiting agent listener %v", err) return } switch v := result.Value.(type) { case AgentInfo: agentInfo(v) case SessionInfo: sessionInfo(v) case ExpiryTimeUpdate: expiryTimeUpdate(v) case HeartBeat: // for not ignoring, can also implement behavior // when heartbeat not received but hearbeat is only // intended to keep the connection up default: fmt.Printf(" Unknown type: %T\n", v) } } } func ListenForServerEvents(channel CommChannel, setUsernamePassword func(user UserPassword)) { for { var result ConvergeMessage err := channel.SideChannel.Decoder.Decode(&result) if err != nil { // TODO more clean solution, need to explicitly close when agent exits. log.Printf("Exiting agent listener %v", err) return } switch v := result.Value.(type) { case UserPassword: setUsernamePassword(v) default: fmt.Printf(" Unknown type: %T\n", v) } } } func CheckProtocolVersion(role Role, conn io.ReadWriter) error { channel := NewGOBChannel(conn) sends := make(chan any) receives := make(chan any) errors := make(chan error) channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) channel.ReceiveAsync(receives, errors) select { case <-time.After(10 * time.Second): log.Println("PROTOCOLVERSION: timeout") return fmt.Errorf("Timeout waiting for protocol version") case err := <-errors: log.Printf("PROTOCOLVERSION: %v", err) return err case protocolVersion := <-receives: otherVersion := protocolVersion.(ProtocolVersion).Version if PROTOCOL_VERSION != otherVersion { switch role { case Agent: log.Printf("Protocol version mismatch: agent %d, converge server %d", PROTOCOL_VERSION, otherVersion) case ConvergeServer: log.Printf("Protocol version mismatch: agent %d, converge server %d", otherVersion, PROTOCOL_VERSION) } return fmt.Errorf("Protocol version mismatch") } log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) return nil } }