converge/pkg/comms/binary.go
Erik Brakkee f82601d07c Lots of refactoring.
Now hijacking the ssh connection setup in the listener to exchange some information before passing the connection on to the SSH server.

Next step is to do the full exchange of the required information. To make that easy, some simple gob-based Read and Write methods with timeouts are needed.
2024-07-26 22:40:56 +02:00

93 lines
2.0 KiB
Go

package comms
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"time"
)
// protocol_version is the version number of the binary protocol in this file.
// NOTE(review): AgentListener.Accept currently sends a hard-coded 99 rather
// than this constant; confirm which value the peer expects.
const (
	protocol_version = 1
)
// This file contains utilities for the binary protocol, as well as a listener that wraps an existing
// listener and performs an exchange with the client as part of accepting a connection, before
// passing the connection on to the application. This is used to pass metadata from the converge server to the ssh agent
// so that messages from the agent to the converge server can be correlated with the client's ssh session.
// SendInt writes val to w as a 4-byte big-endian signed 32-bit integer, the
// encoding that ReceiveInt reads back.
//
// If val does not fit in an int32, nothing is written and an error is
// returned. Truncating it would make the peer receive a different number.
// Otherwise the error (if any) is the one returned by w.
func SendInt(w io.Writer, val int) error {
	val32 := int32(val)
	if int(val32) != val {
		return fmt.Errorf("value %d does not fit in a 32-bit integer", val)
	}
	return binary.Write(w, binary.BigEndian, val32)
}
// ReceiveInt reads a 4-byte big-endian signed 32-bit integer from r, the
// encoding written by SendInt, and returns it as an int.
//
// Errors come from io.ReadFull: io.EOF when r yields no bytes at all, and
// io.ErrUnexpectedEOF when it ends part-way through the value. The returned
// int is 0 whenever the error is non-nil.
func ReceiveInt(r io.Reader) (int, error) {
	var raw [4]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return int(int32(binary.BigEndian.Uint32(raw[:]))), nil
}
// AgentListener wraps another net.Listener. Its Accept performs a
// protocol-version exchange (ExchangeProtocolVersion) on each connection
// obtained from the wrapped listener before handing it to the caller.
type AgentListener struct {
	decorated net.Listener // the wrapped listener that actually accepts connections
}

// NewAgentListener returns an AgentListener that wraps listener.
func NewAgentListener(listener net.Listener) AgentListener {
	var wrapper AgentListener
	wrapper.decorated = listener
	return wrapper
}
// ExchangeProtocolVersion writes version to conn and, concurrently, reads the
// peer's version from conn. It returns the peer's version once BOTH the write
// and the read have completed successfully.
//
// Sending and receiving run concurrently, so the exchange does not depend on
// the peer reading before it writes (or vice versa).
//
// An error is returned when either direction fails, or when both have not
// completed within 10 seconds. In that case a helper goroutine may still be
// blocked in I/O on conn. It exits once that I/O fails, so callers should
// close conn when an error is returned.
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
	type receiveResult struct {
		val int
		err error
	}

	// Each channel has room for the single value its goroutine delivers. That
	// way the goroutines never block on delivery and can exit even after this
	// function has already returned (error or timeout).
	sendDone := make(chan error, 1)
	received := make(chan receiveResult, 1)

	go func(done chan<- error) {
		done <- SendInt(conn, version)
	}(sendDone)
	go func(out chan<- receiveResult) {
		val, err := ReceiveInt(conn)
		out <- receiveResult{val: val, err: err}
	}(received)

	timeout := time.After(10 * time.Second)
	peerVersion := 0

	// Wait for both directions. A channel is set to nil once its result has
	// been consumed. Receiving from a nil channel blocks forever, which
	// removes that case from the select.
	for sendDone != nil || received != nil {
		select {
		case err := <-sendDone:
			if err != nil {
				log.Printf("Error exchanging protocol version %v", err)
				return 0, fmt.Errorf("sending protocol version: %w", err)
			}
			sendDone = nil
		case res := <-received:
			if res.err != nil {
				log.Printf("Error exchanging protocol version %v", res.err)
				return 0, fmt.Errorf("receiving protocol version: %w", res.err)
			}
			peerVersion = res.val
			received = nil
		case <-timeout:
			log.Println("Timeout exchanging protocol version")
			return 0, fmt.Errorf("timeout exchanging protocol version with converge server")
		}
	}

	log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", peerVersion)
	return peerVersion, nil
}
// Accept waits for the next connection on the wrapped listener and performs
// the protocol-version exchange on it. It returns the connection once the
// exchange succeeds.
//
// A connection whose exchange fails is closed and skipped, and Accept keeps
// waiting for the next one. Only errors from the wrapped listener itself are
// returned. Returning a per-connection handshake error would make server loops
// such as net/http.Server.Serve treat it as a listener failure and stop
// serving.
//
// The exchange runs synchronously inside Accept, so a slow peer delays the
// acceptance of later connections by up to the exchange timeout.
func (listener AgentListener) Accept() (net.Conn, error) {
	for {
		conn, err := listener.decorated.Accept()
		if err != nil {
			return nil, err
		}
		// NOTE(review): a hard-coded 99 is sent rather than protocol_version,
		// and the peer's version is ignored. Confirm the intended value with
		// the peer's implementation before changing it (it is wire-visible).
		if _, err := ExchangeProtocolVersion(99, conn); err != nil {
			log.Printf("AgentListener: dropping connection from %v: %v", conn.RemoteAddr(), err)
			// Best effort: closing unblocks any Read/Write still pending on conn.
			conn.Close()
			continue
		}
		return conn, nil
	}
}
// Addr returns the address reported by the wrapped listener.
func (listener AgentListener) Addr() net.Addr { return listener.decorated.Addr() }

// Close closes the wrapped listener and returns its error.
func (listener AgentListener) Close() error { return listener.decorated.Close() }