GOB channel for easily and asynchronously using GOB on a single network connection, also dealing with timeouts and errors in a good way.

Protocol version is now checked when the agent connects to the converge server.

Next up: sending connection metadata and username password from server to agent and sending environment information back to the server. This means then that the side channel will only be used for expiry time messages and session type with the client id passed in so the converge server can than correlate the results back to the correct channel.
This commit is contained in:
Erik Brakkee 2024-07-27 11:21:35 +02:00
parent f82601d07c
commit 621bbd8ca6
9 changed files with 166 additions and 137 deletions

View File

@ -249,6 +249,11 @@ func main() {
wsConn := websocketutil.NewWebSocketConn(conn)
defer wsConn.Close()
err = comms.CheckProtocolVersion(comms.Agent, wsConn)
if err != nil {
os.Exit(1)
}
commChannel, err := comms.NewCommChannel(comms.Agent, wsConn)
if err != nil {
panic(err)

View File

@ -51,7 +51,7 @@ func printHelp(msg string) {
}
func main() {
downloadDir := "downloads"
downloadDir := "../static"
args := os.Args[1:]
for len(args) > 0 && strings.HasPrefix(args[0], "-") {

View File

@ -0,0 +1,36 @@
package comms
import (
"net"
)
type AgentListener struct {
decorated net.Listener
}
func NewAgentListener(listener net.Listener) AgentListener {
return AgentListener{decorated: listener}
}
func (listener AgentListener) Accept() (net.Conn, error) {
conn, err := listener.decorated.Accept()
if err != nil {
return nil, err
}
//_, err = CheckProtocolVersion(Agent, conn)
//if err != nil {
// conn.Close()
// return nil, err
//}
return conn, nil
}
func (listener AgentListener) Close() error {
return listener.decorated.Close()
}
func (listener AgentListener) Addr() net.Addr {
return listener.decorated.Addr()
}

View File

@ -1,7 +1,6 @@
package comms
import (
"encoding/gob"
"fmt"
"github.com/hashicorp/yamux"
"io"
@ -11,7 +10,7 @@ import (
type CommChannel struct {
// a separet connection outside of the ssh session
SideChannel TCPChannel
SideChannel GOBChannel
Session *yamux.Session
}
@ -51,13 +50,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
switch role {
case Agent:
conn, err := commChannel.Session.OpenStream()
commChannel.SideChannel.Peer = conn
commChannel.SideChannel = NewGOBChannel(conn)
if err != nil {
return CommChannel{}, err
}
case ConvergeServer:
conn, err := commChannel.Session.Accept()
commChannel.SideChannel.Peer = conn
commChannel.SideChannel = NewGOBChannel(conn)
if err != nil {
return CommChannel{}, err
}
@ -66,10 +65,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
}
log.Println("Communication channel between agent and converge server established")
RegisterEventsWithGob()
commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer)
commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer)
// heartbeat
if role == Agent {
go func() {
@ -88,7 +83,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
// Sending an event to the other side
func ListenForAgentEvents(channel TCPChannel,
func ListenForAgentEvents(channel GOBChannel,
agentInfo func(agent AgentInfo),
sessionInfo func(session SessionInfo),
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
@ -144,3 +139,38 @@ func ListenForServerEvents(channel CommChannel,
}
}
}
func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
channel := NewGOBChannel(conn)
sends := make(chan any)
receives := make(chan any)
errors := make(chan error)
channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
channel.ReceiveAsync(receives, errors)
select {
case <-time.After(10 * time.Second):
log.Println("PROTOCOLVERSION: timeout")
return fmt.Errorf("Timeout waiting for protocol version")
case err := <-errors:
log.Printf("PROTOCOLVERSION: %v", err)
return err
case protocolVersion := <-receives:
otherVersion := protocolVersion.(ProtocolVersion).Version
if PROTOCOL_VERSION != otherVersion {
switch role {
case Agent:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
PROTOCOL_VERSION, otherVersion)
case ConvergeServer:
log.Printf("Protocol version mismatch: agent %d, converge server %d",
otherVersion, PROTOCOL_VERSION)
}
return fmt.Errorf("Protocol version mismatch")
}
log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version)
return nil
}
}

View File

@ -1,92 +0,0 @@
package comms
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"time"
)
const protocol_version = 1
// this file contains utilities for the binary protocol as well as a listener that wraps an existing
// listener and does an exchange with the client as part of accepting the connection before
// passing it on the application. This is used to pass metadata from converge server to the ssh agent
// so that messages from agent to converge server can be correlated with the client ssh session.
func SendInt(w io.Writer, val int) error {
val32 := int32(val)
return binary.Write(w, binary.BigEndian, val32)
}
func ReceiveInt(r io.Reader) (int, error) {
var val int32
err := binary.Read(r, binary.BigEndian, &val)
return int(val), err
}
type AgentListener struct {
decorated net.Listener
}
func NewAgentListener(listener net.Listener) AgentListener {
return AgentListener{decorated: listener}
}
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
errors := make(chan error)
values := make(chan int)
go func() {
err := SendInt(conn, version)
if err != nil {
errors <- err
}
}()
go func() {
val, err := ReceiveInt(conn)
if err != nil {
errors <- err
} else {
values <- val
}
}()
select {
case err := <-errors:
log.Printf("Error exchanging protocol version %v", err)
return 0, err
case <-time.After(10 * time.Second):
log.Println("Timeout exchanging protocol version")
return 0, fmt.Errorf("Timeout echangeing protocol version with converge server")
case val := <-values:
log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val)
return val, nil
}
}
func (listener AgentListener) Accept() (net.Conn, error) {
conn, err := listener.decorated.Accept()
if err != nil {
return nil, err
}
_, err = ExchangeProtocolVersion(99, conn)
if err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func (listener AgentListener) Close() error {
return listener.decorated.Close()
}
func (listener AgentListener) Addr() net.Addr {
return listener.decorated.Addr()
}

View File

@ -8,6 +8,12 @@ import (
"time"
)
const PROTOCOL_VERSION = 1
func init() {
RegisterEventsWithGob()
}
// Client to server events
type AgentInfo struct {
@ -32,11 +38,20 @@ type HeartBeat struct {
// Message sent from converge server to agent
type ProtocolVersion struct {
Version int
}
type UserPassword struct {
Username string
Password string
}
type ConnectionInfo struct {
ConnectionId int
UserPassword UserPassword
}
// Generic wrapper message required to send messages of arbitrary type
type ConvergeMessage struct {
@ -71,7 +86,9 @@ func RegisterEventsWithGob() {
gob.Register(HeartBeat{})
// ConvergeServer to Agent
gob.Register(ProtocolVersion{})
gob.Register(UserPassword{})
gob.Register(ConnectionInfo{})
// Wrapper event.
gob.Register(ConvergeMessage{})

62
pkg/comms/gobchannel.go Normal file
View File

@ -0,0 +1,62 @@
package comms
import (
"encoding/gob"
"io"
"log"
)
type GOBChannel struct {
// can be any connection, including the ssh connnection before it is
// passed on to SSH during initialization of converge to agent communication
Peer io.ReadWriter
Encoder *gob.Encoder
Decoder *gob.Decoder
}
func NewGOBChannel(conn io.ReadWriter) GOBChannel {
return GOBChannel{
Peer: conn,
Encoder: gob.NewEncoder(conn),
Decoder: gob.NewDecoder(conn),
}
}
// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) {
go func() {
err := channel.Send(obj)
if err != nil {
errors <- err
} else {
done <- true
}
}()
}
func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) {
go func() {
value, err := channel.Receive()
if err != nil {
errors <- err
} else {
result <- value
}
}()
}
func (channel GOBChannel) Send(object any) error {
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
if err != nil {
log.Printf("Encoding error %v", err)
}
return err
}
func (channel GOBChannel) Receive() (any, error) {
var target ConvergeMessage
err := channel.Decoder.Decode(&target)
return target.Value, err
}

View File

@ -1,33 +0,0 @@
package comms
import (
"encoding/gob"
"log"
"net"
"time"
)
type TCPChannel struct {
// can be any connection, including the ssh connnection before it is
// passed on to SSH during initialization of converge to agent communication
Peer net.Conn
Encoder *gob.Encoder
Decoder *gob.Decoder
}
// Synchronous functions with timeouts and error handling.
func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error {
return nil
}
func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error {
return nil
}
func (channel TCPChannel) Send(object any) error {
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
if err != nil {
log.Printf("Encoding error %v", err)
}
return err
}

View File

@ -171,6 +171,12 @@ func (admin *Admin) RemoveClient(client *Client) error {
func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
userPassword comms.UserPassword) error {
defer conn.Close()
err := comms.CheckProtocolVersion(comms.ConvergeServer, conn)
if err != nil {
return err
}
// TODO: remove agent return value
agent, err := admin.addAgent(publicId, conn)
if err != nil {
@ -223,8 +229,6 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser
}()
log.Printf("Connecting client and agent: '%s'\n", publicId)
comms.ExchangeProtocolVersion(1111, client.agent)
iowrappers.SynchronizeStreams(client.client, client.agent)
return nil
}