diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 3cf05e0..524ce88 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -291,7 +291,10 @@ func main() { log.Println() agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval) - service.Run(commChannel.Session) + + listener := comms.NewAgentListener(commChannel.Session) + + service.Run(listener) } func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 3bf0136..f195308 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -101,8 +101,8 @@ func ConfigureAgent(commChannel comms.CommChannel, log.Printf("Agent expires at %s", state.expiryTime(holdFilename).Format(time.DateTime)) - comms.Send(state.commChannel, comms.NewAgentInfo()) - comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + state.commChannel.SideChannel.Send(comms.NewAgentInfo()) + state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) go func() { for { @@ -182,7 +182,7 @@ func login(sessionId int, sshSession ssh.Session) { if sessionType == "" { sessionType = "ssh" } - comms.Send(state.commChannel, comms.NewSessionInfo(sessionType)) + state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType)) holdFileStats, ok := fileExistsWithStats(holdFilename) if ok { @@ -297,7 +297,7 @@ func holdFileChange() { message += holdFileMessage() messageUsers(message) state.lastExpiryTimmeReported = newExpiryTIme - comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) } } diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 69fa18b..ba7fef2 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -6,21 +6,13 @@ import ( "github.com/hashicorp/yamux" "io" "log" - "net" "time" ) type CommChannel struct { - Peer net.Conn - Encoder *gob.Encoder - Decoder *gob.Decoder - Session *yamux.Session -} - -type AgentListener interface { - AgentInfo(agent AgentInfo) - SessionInfo(session SessionInfo) - ExpiryTimeUpdate(session ExpiryTimeUpdate) + // a separet connection outside of the ssh session + SideChannel TCPChannel + Session *yamux.Session } type Role int @@ -39,7 +31,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { return CommChannel{}, err } commChannel = CommChannel{ - Peer: nil, Session: listener, } case ConvergeServer: @@ -48,7 +39,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { return CommChannel{}, err } commChannel = CommChannel{ - Peer: nil, Session: clientSession, } default: @@ -61,13 +51,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { switch role { case Agent: conn, err := commChannel.Session.OpenStream() - commChannel.Peer = conn + commChannel.SideChannel.Peer = conn if err != nil { return CommChannel{}, err } case ConvergeServer: conn, err := commChannel.Session.Accept() - commChannel.Peer = conn + commChannel.SideChannel.Peer = conn if err != nil { return CommChannel{}, err } @@ -77,15 +67,15 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { log.Println("Communication channel between agent and converge server established") RegisterEventsWithGob() - commChannel.Encoder = gob.NewEncoder(commChannel.Peer) - commChannel.Decoder = gob.NewDecoder(commChannel.Peer) + commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer) + commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer) // heartbeat if role == Agent { go func() { for { time.Sleep(10 * time.Second) - err := Send(commChannel, HeartBeat{}) + err := commChannel.SideChannel.Send(HeartBeat{}) if err != nil { log.Println("Sending heartbeat to server failed") } @@ -98,15 +88,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { // Sending an event to the other side -func Send(commChannel CommChannel, object any) error { - err := commChannel.Encoder.Encode(ConvergeMessage{Value: object}) - if err != nil { - log.Printf("Encoding error %v", err) - } - return err -} - -func ListenForAgentEvents(channel CommChannel, +func ListenForAgentEvents(channel TCPChannel, agentInfo func(agent AgentInfo), sessionInfo func(session SessionInfo), expiryTimeUpdate func(session ExpiryTimeUpdate)) { @@ -145,7 +127,7 @@ func ListenForServerEvents(channel CommChannel, setUsernamePassword func(user UserPassword)) { for { var result ConvergeMessage - err := channel.Decoder.Decode(&result) + err := channel.SideChannel.Decoder.Decode(&result) if err != nil { // TODO more clean solution, need to explicitly close when agent exits. diff --git a/pkg/comms/binary.go b/pkg/comms/binary.go new file mode 100644 index 0000000..13e48d3 --- /dev/null +++ b/pkg/comms/binary.go @@ -0,0 +1,92 @@ +package comms + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "time" +) + +const protocol_version = 1 + +// this file contains utilities for the binary protocol as well as a listener that wraps an existing +// listener and does an exchange with the client as part of accepting the connection before +// passing it on the application. This is used to pass metadata from converge server to the ssh agent +// so that messages from agent to converge server can be correlated with the client ssh session. + +func SendInt(w io.Writer, val int) error { + val32 := int32(val) + return binary.Write(w, binary.BigEndian, val32) +} + +func ReceiveInt(r io.Reader) (int, error) { + var val int32 + err := binary.Read(r, binary.BigEndian, &val) + return int(val), err +} + +type AgentListener struct { + decorated net.Listener +} + +func NewAgentListener(listener net.Listener) AgentListener { + return AgentListener{decorated: listener} +} + +func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) { + errors := make(chan error) + values := make(chan int) + + go func() { + err := SendInt(conn, version) + if err != nil { + errors <- err + } + }() + go func() { + val, err := ReceiveInt(conn) + if err != nil { + errors <- err + } else { + values <- val + } + + }() + + select { + case err := <-errors: + log.Printf("Error exchanging protocol version %v", err) + return 0, err + case <-time.After(10 * time.Second): + log.Println("Timeout exchanging protocol version") + return 0, fmt.Errorf("Timeout echangeing protocol version with converge server") + case val := <-values: + log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val) + return val, nil + } +} + +func (listener AgentListener) Accept() (net.Conn, error) { + conn, err := listener.decorated.Accept() + if err != nil { + return nil, err + } + + _, err = ExchangeProtocolVersion(99, conn) + if err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func (listener AgentListener) Close() error { + return listener.decorated.Close() +} + +func (listener AgentListener) Addr() net.Addr { + return listener.decorated.Addr() +} diff --git a/pkg/comms/tcpchannel.go b/pkg/comms/tcpchannel.go new file mode 100644 index 0000000..9eb5ade --- /dev/null +++ b/pkg/comms/tcpchannel.go @@ -0,0 +1,33 @@ +package comms + +import ( + "encoding/gob" + "log" + "net" + "time" +) + +type TCPChannel struct { + // can be any connection, including the ssh connnection before it is + // passed on to SSH during initialization of converge to agent communication + Peer net.Conn + Encoder *gob.Encoder + Decoder *gob.Decoder +} + +// Synchronous functions with timeouts and error handling. +func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error { + return nil +} + +func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error { + return nil +} + +func (channel TCPChannel) Send(object any) error { + err := channel.Encoder.Encode(ConvergeMessage{Value: object}) + if err != nil { + log.Printf("Encoding error %v", err) + } + return err +} diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index 29b6790..ad302d8 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -181,10 +181,10 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, }() log.Println("Sending username and password to agent") - comms.Send(agent.commChannel, userPassword) + agent.commChannel.SideChannel.Send(userPassword) go func() { - comms.ListenForAgentEvents(agent.commChannel, + comms.ListenForAgentEvents(agent.commChannel.SideChannel, func(info comms.AgentInfo) { agent.agentInfo = info admin.logStatus() @@ -223,6 +223,8 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser }() log.Printf("Connecting client and agent: '%s'\n", publicId) + comms.ExchangeProtocolVersion(1111, client.agent) + iowrappers.SynchronizeStreams(client.client, client.agent) return nil }