diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 524ce88..00bef69 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -249,6 +249,11 @@ func main() { wsConn := websocketutil.NewWebSocketConn(conn) defer wsConn.Close() + err = comms.CheckProtocolVersion(comms.Agent, wsConn) + if err != nil { + os.Exit(1) + } + commChannel, err := comms.NewCommChannel(comms.Agent, wsConn) if err != nil { panic(err) diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go index 1faac26..8cb6910 100644 --- a/cmd/converge/converge.go +++ b/cmd/converge/converge.go @@ -51,7 +51,7 @@ func printHelp(msg string) { } func main() { - downloadDir := "downloads" + downloadDir := "../static" args := os.Args[1:] for len(args) > 0 && strings.HasPrefix(args[0], "-") { diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go new file mode 100644 index 0000000..9c4928d --- /dev/null +++ b/pkg/comms/agentlistener.go @@ -0,0 +1,36 @@ +package comms + +import ( + "net" +) + +type AgentListener struct { + decorated net.Listener +} + +func NewAgentListener(listener net.Listener) AgentListener { + return AgentListener{decorated: listener} +} + +func (listener AgentListener) Accept() (net.Conn, error) { + conn, err := listener.decorated.Accept() + if err != nil { + return nil, err + } + + //_, err = CheckProtocolVersion(Agent, conn) + //if err != nil { + // conn.Close() + // return nil, err + //} + + return conn, nil +} + +func (listener AgentListener) Close() error { + return listener.decorated.Close() +} + +func (listener AgentListener) Addr() net.Addr { + return listener.decorated.Addr() +} diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index ba7fef2..3ad5784 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -1,7 +1,6 @@ package comms import ( - "encoding/gob" "fmt" "github.com/hashicorp/yamux" "io" @@ -11,7 +10,7 @@ import ( type CommChannel struct { // a separet connection outside of the ssh session - SideChannel TCPChannel + SideChannel GOBChannel Session *yamux.Session } @@ -51,13 +50,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { switch role { case Agent: conn, err := commChannel.Session.OpenStream() - commChannel.SideChannel.Peer = conn + commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } case ConvergeServer: conn, err := commChannel.Session.Accept() - commChannel.SideChannel.Peer = conn + commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } @@ -66,10 +65,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { } log.Println("Communication channel between agent and converge server established") - RegisterEventsWithGob() - commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer) - commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer) - // heartbeat if role == Agent { go func() { @@ -88,7 +83,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { // Sending an event to the other side -func ListenForAgentEvents(channel TCPChannel, +func ListenForAgentEvents(channel GOBChannel, agentInfo func(agent AgentInfo), sessionInfo func(session SessionInfo), expiryTimeUpdate func(session ExpiryTimeUpdate)) { @@ -144,3 +139,38 @@ func ListenForServerEvents(channel CommChannel, } } } + +func CheckProtocolVersion(role Role, conn io.ReadWriter) error { + channel := NewGOBChannel(conn) + + sends := make(chan any) + receives := make(chan any) + errors := make(chan error) + + channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) + channel.ReceiveAsync(receives, errors) + + select { + case <-time.After(10 * time.Second): + log.Println("PROTOCOLVERSION: timeout") + return fmt.Errorf("Timeout waiting for protocol version") + case err := <-errors: + log.Printf("PROTOCOLVERSION: %v", err) + return err + case protocolVersion := <-receives: + otherVersion := protocolVersion.(ProtocolVersion).Version + if PROTOCOL_VERSION != otherVersion { + switch role { + case Agent: + log.Printf("Protocol version mismatch: agent %d, converge server %d", + PROTOCOL_VERSION, otherVersion) + case ConvergeServer: + log.Printf("Protocol version mismatch: agent %d, converge server %d", + otherVersion, PROTOCOL_VERSION) + } + return fmt.Errorf("Protocol version mismatch") + } + log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) + return nil + } +} diff --git a/pkg/comms/binary.go b/pkg/comms/binary.go deleted file mode 100644 index 13e48d3..0000000 --- a/pkg/comms/binary.go +++ /dev/null @@ -1,92 +0,0 @@ -package comms - -import ( - "encoding/binary" - "fmt" - "io" - "log" - "net" - "time" -) - -const protocol_version = 1 - -// this file contains utilities for the binary protocol as well as a listener that wraps an existing -// listener and does an exchange with the client as part of accepting the connection before -// passing it on the application. This is used to pass metadata from converge server to the ssh agent -// so that messages from agent to converge server can be correlated with the client ssh session. - -func SendInt(w io.Writer, val int) error { - val32 := int32(val) - return binary.Write(w, binary.BigEndian, val32) -} - -func ReceiveInt(r io.Reader) (int, error) { - var val int32 - err := binary.Read(r, binary.BigEndian, &val) - return int(val), err -} - -type AgentListener struct { - decorated net.Listener -} - -func NewAgentListener(listener net.Listener) AgentListener { - return AgentListener{decorated: listener} -} - -func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) { - errors := make(chan error) - values := make(chan int) - - go func() { - err := SendInt(conn, version) - if err != nil { - errors <- err - } - }() - go func() { - val, err := ReceiveInt(conn) - if err != nil { - errors <- err - } else { - values <- val - } - - }() - - select { - case err := <-errors: - log.Printf("Error exchanging protocol version %v", err) - return 0, err - case <-time.After(10 * time.Second): - log.Println("Timeout exchanging protocol version") - return 0, fmt.Errorf("Timeout echangeing protocol version with converge server") - case val := <-values: - log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val) - return val, nil - } -} - -func (listener AgentListener) Accept() (net.Conn, error) { - conn, err := listener.decorated.Accept() - if err != nil { - return nil, err - } - - _, err = ExchangeProtocolVersion(99, conn) - if err != nil { - conn.Close() - return nil, err - } - - return conn, nil -} - -func (listener AgentListener) Close() error { - return listener.decorated.Close() -} - -func (listener AgentListener) Addr() net.Addr { - return listener.decorated.Addr() -} diff --git a/pkg/comms/events.go b/pkg/comms/events.go index b5b4e60..3147945 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -8,6 +8,12 @@ import ( "time" ) +const PROTOCOL_VERSION = 1 + +func init() { + RegisterEventsWithGob() +} + // Client to server events type AgentInfo struct { @@ -32,11 +38,20 @@ type HeartBeat struct { // Message sent from converge server to agent +type ProtocolVersion struct { + Version int +} + type UserPassword struct { Username string Password string } +type ConnectionInfo struct { + ConnectionId int + UserPassword UserPassword +} + // Generic wrapper message required to send messages of arbitrary type type ConvergeMessage struct { @@ -71,7 +86,9 @@ func RegisterEventsWithGob() { gob.Register(HeartBeat{}) // ConvergeServer to Agent + gob.Register(ProtocolVersion{}) gob.Register(UserPassword{}) + gob.Register(ConnectionInfo{}) // Wrapper event. gob.Register(ConvergeMessage{}) diff --git a/pkg/comms/gobchannel.go b/pkg/comms/gobchannel.go new file mode 100644 index 0000000..d5ed654 --- /dev/null +++ b/pkg/comms/gobchannel.go @@ -0,0 +1,62 @@ +package comms + +import ( + "encoding/gob" + "io" + "log" +) + +type GOBChannel struct { + // can be any connection, including the ssh connnection before it is + // passed on to SSH during initialization of converge to agent communication + Peer io.ReadWriter + Encoder *gob.Encoder + Decoder *gob.Decoder +} + +func NewGOBChannel(conn io.ReadWriter) GOBChannel { + return GOBChannel{ + Peer: conn, + Encoder: gob.NewEncoder(conn), + Decoder: gob.NewDecoder(conn), + } +} + +// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of +// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts. + +func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) { + go func() { + err := channel.Send(obj) + if err != nil { + errors <- err + } else { + done <- true + } + }() +} + +func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { + go func() { + value, err := channel.Receive() + if err != nil { + errors <- err + } else { + result <- value + } + }() +} + +func (channel GOBChannel) Send(object any) error { + err := channel.Encoder.Encode(ConvergeMessage{Value: object}) + if err != nil { + log.Printf("Encoding error %v", err) + } + return err +} + +func (channel GOBChannel) Receive() (any, error) { + var target ConvergeMessage + err := channel.Decoder.Decode(&target) + return target.Value, err +} diff --git a/pkg/comms/tcpchannel.go b/pkg/comms/tcpchannel.go deleted file mode 100644 index 9eb5ade..0000000 --- a/pkg/comms/tcpchannel.go +++ /dev/null @@ -1,33 +0,0 @@ -package comms - -import ( - "encoding/gob" - "log" - "net" - "time" -) - -type TCPChannel struct { - // can be any connection, including the ssh connnection before it is - // passed on to SSH during initialization of converge to agent communication - Peer net.Conn - Encoder *gob.Encoder - Decoder *gob.Decoder -} - -// Synchronous functions with timeouts and error handling. -func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error { - return nil -} - -func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error { - return nil -} - -func (channel TCPChannel) Send(object any) error { - err := channel.Encoder.Encode(ConvergeMessage{Value: object}) - if err != nil { - log.Printf("Encoding error %v", err) - } - return err -} diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index ad302d8..1489377 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -171,6 +171,12 @@ func (admin *Admin) RemoveClient(client *Client) error { func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, userPassword comms.UserPassword) error { defer conn.Close() + + err := comms.CheckProtocolVersion(comms.ConvergeServer, conn) + if err != nil { + return err + } + // TODO: remove agent return value agent, err := admin.addAgent(publicId, conn) if err != nil { @@ -223,8 +229,6 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser }() log.Printf("Connecting client and agent: '%s'\n", publicId) - comms.ExchangeProtocolVersion(1111, client.agent) - iowrappers.SynchronizeStreams(client.client, client.agent) return nil }