package comms

import (
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"math"
	"net"
	"time"
)

// protocol_version is this package's protocol version number.
// NOTE(review): AgentListener.Accept currently sends 99 rather than this
// value — confirm which one is intended.
const protocol_version = 1

// This file contains utilities for the binary protocol, as well as a listener
// that wraps an existing listener and performs an exchange with the client as
// part of accepting each connection, before passing it on to the application.
// This is used to pass metadata from the converge server to the ssh agent so
// that messages from the agent to the converge server can be correlated with
// the client ssh session.

// SendInt writes val to w as a 4-byte big-endian signed (int32) integer.
// If val does not fit in an int32, nothing is written and an error is
// returned instead of silently truncating the value.
func SendInt(w io.Writer, val int) error {
	if val < math.MinInt32 || val > math.MaxInt32 {
		return fmt.Errorf("value %d does not fit in the 32-bit wire format", val)
	}
	return binary.Write(w, binary.BigEndian, int32(val))
}

// ReceiveInt reads a 4-byte big-endian signed integer from r, as written by
// SendInt. On error it returns 0 together with the error.
func ReceiveInt(r io.Reader) (int, error) {
	var val int32
	if err := binary.Read(r, binary.BigEndian, &val); err != nil {
		return 0, err
	}
	return int(val), nil
}

// AgentListener is a net.Listener that performs a protocol-version exchange
// on every accepted connection before handing it to the caller.
type AgentListener struct {
	decorated net.Listener
}

// Compile-time check that AgentListener satisfies net.Listener.
var _ net.Listener = AgentListener{}

// NewAgentListener wraps listener so that each accepted connection first
// completes a protocol-version exchange.
func NewAgentListener(listener net.Listener) AgentListener {
	return AgentListener{decorated: listener}
}

// ExchangeProtocolVersion concurrently sends version to conn and reads the
// peer's version from conn, returning the peer's value.
//
// Both directions must complete within 10 seconds, or a timeout error is
// returned. The first I/O error from either direction is returned as-is.
// conn is never closed here. After a timeout, the helper goroutines remain
// blocked in I/O until the caller closes the underlying connection, as
// Accept does. They then exit, because their result channels are buffered
// and never block.
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
	const exchangeTimeout = 10 * time.Second

	type recvResult struct {
		val int
		err error
	}
	// Capacity 1: each goroutine sends exactly once, so it can always finish
	// even after this function has stopped listening.
	sendDone := make(chan error, 1)
	recvDone := make(chan recvResult, 1)

	go func() {
		sendDone <- SendInt(conn, version)
	}()
	go func() {
		val, err := ReceiveInt(conn)
		recvDone <- recvResult{val: val, err: err}
	}()

	timer := time.NewTimer(exchangeTimeout)
	defer timer.Stop()

	var (
		peer           int
		sent, received bool
	)
	// Wait for BOTH directions so the caller never starts using conn while
	// our version is still being written.
	for !sent || !received {
		select {
		case err := <-sendDone:
			if err != nil {
				log.Printf("Error exchanging protocol version %v", err)
				return 0, err
			}
			sent = true
		case res := <-recvDone:
			if res.err != nil {
				log.Printf("Error exchanging protocol version %v", res.err)
				return 0, res.err
			}
			peer = res.val
			received = true
		case <-timer.C:
			log.Println("Timeout exchanging protocol version")
			return 0, fmt.Errorf("timeout exchanging protocol version with converge server")
		}
	}

	log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", peer)
	return peer, nil
}

// Accept waits for the next connection from the wrapped listener and
// performs the version exchange on it.
//
// If the exchange fails, that connection is closed and skipped. Accept then
// waits for the next one, so a single misbehaving client is not reported as
// a listener failure. Many accept loops (e.g. net/http.Server.Serve) stop on
// a non-temporary Accept error. Errors from the wrapped listener's own
// Accept are returned unchanged.
//
// NOTE(review): the exchange runs synchronously, so a client that never
// answers delays every other pending connection for up to the exchange
// timeout. It also sends 99 rather than protocol_version and discards the
// peer's reply — confirm that is intended.
func (listener AgentListener) Accept() (net.Conn, error) {
	for {
		conn, err := listener.decorated.Accept()
		if err != nil {
			return nil, err
		}
		if _, err := ExchangeProtocolVersion(99, conn); err != nil {
			// Best effort: the handshake already failed (and was logged by
			// ExchangeProtocolVersion). Closing also unblocks its I/O goroutines.
			_ = conn.Close()
			continue
		}
		return conn, nil
	}
}

// Close closes the wrapped listener.
func (listener AgentListener) Close() error {
	return listener.decorated.Close()
}

// Addr returns the wrapped listener's network address.
func (listener AgentListener) Addr() net.Addr {
	return listener.decorated.Addr()
}