From 621bbd8ca6a3e8abec55874bce88b367a271fd77 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Sat, 27 Jul 2024 11:21:35 +0200 Subject: [PATCH] GOB channel for easily and asynchronously using GOB on a single network connection, also dealing with timeouts and errors in a good way. Protocol version is now checked when the agent connects to the converge server. Next up: sending connection metadata and username password from server to agent and sending environment information back to the server. This means then that the side channel will only be used for expiry time messages and session type with the client id passed in so the converge server can than correlate the results back to the correct channel. --- cmd/agent/agent.go | 5 +++ cmd/converge/converge.go | 2 +- pkg/comms/agentlistener.go | 36 +++++++++++++++ pkg/comms/agentserver.go | 48 ++++++++++++++++---- pkg/comms/binary.go | 92 -------------------------------------- pkg/comms/events.go | 17 +++++++ pkg/comms/gobchannel.go | 62 +++++++++++++++++++++++++ pkg/comms/tcpchannel.go | 33 -------------- pkg/converge/admin.go | 8 +++- 9 files changed, 166 insertions(+), 137 deletions(-) create mode 100644 pkg/comms/agentlistener.go delete mode 100644 pkg/comms/binary.go create mode 100644 pkg/comms/gobchannel.go delete mode 100644 pkg/comms/tcpchannel.go diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 524ce88..00bef69 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -249,6 +249,11 @@ func main() { wsConn := websocketutil.NewWebSocketConn(conn) defer wsConn.Close() + err = comms.CheckProtocolVersion(comms.Agent, wsConn) + if err != nil { + os.Exit(1) + } + commChannel, err := comms.NewCommChannel(comms.Agent, wsConn) if err != nil { panic(err) diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go index 1faac26..8cb6910 100644 --- a/cmd/converge/converge.go +++ b/cmd/converge/converge.go @@ -51,7 +51,7 @@ func printHelp(msg string) { } func main() { - downloadDir := "downloads" + downloadDir := "../static" args := os.Args[1:] for len(args) > 0 && strings.HasPrefix(args[0], "-") { diff --git a/pkg/comms/agentlistener.go b/pkg/comms/agentlistener.go new file mode 100644 index 0000000..9c4928d --- /dev/null +++ b/pkg/comms/agentlistener.go @@ -0,0 +1,36 @@ +package comms + +import ( + "net" +) + +type AgentListener struct { + decorated net.Listener +} + +func NewAgentListener(listener net.Listener) AgentListener { + return AgentListener{decorated: listener} +} + +func (listener AgentListener) Accept() (net.Conn, error) { + conn, err := listener.decorated.Accept() + if err != nil { + return nil, err + } + + //_, err = CheckProtocolVersion(Agent, conn) + //if err != nil { + // conn.Close() + // return nil, err + //} + + return conn, nil +} + +func (listener AgentListener) Close() error { + return listener.decorated.Close() +} + +func (listener AgentListener) Addr() net.Addr { + return listener.decorated.Addr() +} diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index ba7fef2..3ad5784 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -1,7 +1,6 @@ package comms import ( - "encoding/gob" "fmt" "github.com/hashicorp/yamux" "io" @@ -11,7 +10,7 @@ import ( type CommChannel struct { // a separet connection outside of the ssh session - SideChannel TCPChannel + SideChannel GOBChannel Session *yamux.Session } @@ -51,13 +50,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { switch role { case Agent: conn, err := commChannel.Session.OpenStream() - commChannel.SideChannel.Peer = conn + commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } case ConvergeServer: conn, err := commChannel.Session.Accept() - commChannel.SideChannel.Peer = conn + commChannel.SideChannel = NewGOBChannel(conn) if err != nil { return CommChannel{}, err } @@ -66,10 +65,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { } log.Println("Communication channel between agent and converge server established") - RegisterEventsWithGob() - commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer) - commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer) - // heartbeat if role == Agent { go func() { @@ -88,7 +83,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { // Sending an event to the other side -func ListenForAgentEvents(channel TCPChannel, +func ListenForAgentEvents(channel GOBChannel, agentInfo func(agent AgentInfo), sessionInfo func(session SessionInfo), expiryTimeUpdate func(session ExpiryTimeUpdate)) { @@ -144,3 +139,38 @@ func ListenForServerEvents(channel CommChannel, } } } + +func CheckProtocolVersion(role Role, conn io.ReadWriter) error { + channel := NewGOBChannel(conn) + + sends := make(chan any) + receives := make(chan any) + errors := make(chan error) + + channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) + channel.ReceiveAsync(receives, errors) + + select { + case <-time.After(10 * time.Second): + log.Println("PROTOCOLVERSION: timeout") + return fmt.Errorf("Timeout waiting for protocol version") + case err := <-errors: + log.Printf("PROTOCOLVERSION: %v", err) + return err + case protocolVersion := <-receives: + otherVersion := protocolVersion.(ProtocolVersion).Version + if PROTOCOL_VERSION != otherVersion { + switch role { + case Agent: + log.Printf("Protocol version mismatch: agent %d, converge server %d", + PROTOCOL_VERSION, otherVersion) + case ConvergeServer: + log.Printf("Protocol version mismatch: agent %d, converge server %d", + otherVersion, PROTOCOL_VERSION) + } + return fmt.Errorf("Protocol version mismatch") + } + log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) + return nil + } +} diff --git a/pkg/comms/binary.go b/pkg/comms/binary.go deleted file mode 100644 index 13e48d3..0000000 --- a/pkg/comms/binary.go +++ /dev/null @@ -1,92 +0,0 @@ -package comms - -import ( - "encoding/binary" - "fmt" - "io" - "log" - "net" - "time" -) - -const protocol_version = 1 - -// this file contains utilities for the binary protocol as well as a listener that wraps an existing -// listener and does an exchange with the client as part of accepting the connection before -// passing it on the application. This is used to pass metadata from converge server to the ssh agent -// so that messages from agent to converge server can be correlated with the client ssh session. - -func SendInt(w io.Writer, val int) error { - val32 := int32(val) - return binary.Write(w, binary.BigEndian, val32) -} - -func ReceiveInt(r io.Reader) (int, error) { - var val int32 - err := binary.Read(r, binary.BigEndian, &val) - return int(val), err -} - -type AgentListener struct { - decorated net.Listener -} - -func NewAgentListener(listener net.Listener) AgentListener { - return AgentListener{decorated: listener} -} - -func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) { - errors := make(chan error) - values := make(chan int) - - go func() { - err := SendInt(conn, version) - if err != nil { - errors <- err - } - }() - go func() { - val, err := ReceiveInt(conn) - if err != nil { - errors <- err - } else { - values <- val - } - - }() - - select { - case err := <-errors: - log.Printf("Error exchanging protocol version %v", err) - return 0, err - case <-time.After(10 * time.Second): - log.Println("Timeout exchanging protocol version") - return 0, fmt.Errorf("Timeout echangeing protocol version with converge server") - case val := <-values: - log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val) - return val, nil - } -} - -func (listener AgentListener) Accept() (net.Conn, error) { - conn, err := listener.decorated.Accept() - if err != nil { - return nil, err - } - - _, err = ExchangeProtocolVersion(99, conn) - if err != nil { - conn.Close() - return nil, err - } - - return conn, nil -} - -func (listener AgentListener) Close() error { - return listener.decorated.Close() -} - -func (listener AgentListener) Addr() net.Addr { - return listener.decorated.Addr() -} diff --git a/pkg/comms/events.go b/pkg/comms/events.go index b5b4e60..3147945 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -8,6 +8,12 @@ import ( "time" ) +const PROTOCOL_VERSION = 1 + +func init() { + RegisterEventsWithGob() +} + // Client to server events type AgentInfo struct { @@ -32,11 +38,20 @@ type HeartBeat struct { // Message sent from converge server to agent +type ProtocolVersion struct { + Version int +} + type UserPassword struct { Username string Password string } +type ConnectionInfo struct { + ConnectionId int + UserPassword UserPassword +} + // Generic wrapper message required to send messages of arbitrary type type ConvergeMessage struct { @@ -71,7 +86,9 @@ func RegisterEventsWithGob() { gob.Register(HeartBeat{}) // ConvergeServer to Agent + gob.Register(ProtocolVersion{}) gob.Register(UserPassword{}) + gob.Register(ConnectionInfo{}) // Wrapper event. gob.Register(ConvergeMessage{}) diff --git a/pkg/comms/gobchannel.go b/pkg/comms/gobchannel.go new file mode 100644 index 0000000..d5ed654 --- /dev/null +++ b/pkg/comms/gobchannel.go @@ -0,0 +1,62 @@ +package comms + +import ( + "encoding/gob" + "io" + "log" +) + +type GOBChannel struct { + // can be any connection, including the ssh connnection before it is + // passed on to SSH during initialization of converge to agent communication + Peer io.ReadWriter + Encoder *gob.Encoder + Decoder *gob.Decoder +} + +func NewGOBChannel(conn io.ReadWriter) GOBChannel { + return GOBChannel{ + Peer: conn, + Encoder: gob.NewEncoder(conn), + Decoder: gob.NewDecoder(conn), + } +} + +// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of +// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts. + +func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) { + go func() { + err := channel.Send(obj) + if err != nil { + errors <- err + } else { + done <- true + } + }() +} + +func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { + go func() { + value, err := channel.Receive() + if err != nil { + errors <- err + } else { + result <- value + } + }() +} + +func (channel GOBChannel) Send(object any) error { + err := channel.Encoder.Encode(ConvergeMessage{Value: object}) + if err != nil { + log.Printf("Encoding error %v", err) + } + return err +} + +func (channel GOBChannel) Receive() (any, error) { + var target ConvergeMessage + err := channel.Decoder.Decode(&target) + return target.Value, err +} diff --git a/pkg/comms/tcpchannel.go b/pkg/comms/tcpchannel.go deleted file mode 100644 index 9eb5ade..0000000 --- a/pkg/comms/tcpchannel.go +++ /dev/null @@ -1,33 +0,0 @@ -package comms - -import ( - "encoding/gob" - "log" - "net" - "time" -) - -type TCPChannel struct { - // can be any connection, including the ssh connnection before it is - // passed on to SSH during initialization of converge to agent communication - Peer net.Conn - Encoder *gob.Encoder - Decoder *gob.Decoder -} - -// Synchronous functions with timeouts and error handling. -func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error { - return nil -} - -func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error { - return nil -} - -func (channel TCPChannel) Send(object any) error { - err := channel.Encoder.Encode(ConvergeMessage{Value: object}) - if err != nil { - log.Printf("Encoding error %v", err) - } - return err -} diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index ad302d8..1489377 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -171,6 +171,12 @@ func (admin *Admin) RemoveClient(client *Client) error { func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, userPassword comms.UserPassword) error { defer conn.Close() + + err := comms.CheckProtocolVersion(comms.ConvergeServer, conn) + if err != nil { + return err + } + // TODO: remove agent return value agent, err := admin.addAgent(publicId, conn) if err != nil { @@ -223,8 +229,6 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser }() log.Printf("Connecting client and agent: '%s'\n", publicId) - comms.ExchangeProtocolVersion(1111, client.agent) - iowrappers.SynchronizeStreams(client.client, client.agent) return nil }