From f82601d07c8825e2e7f00b915cd4a08df8be9da9 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Fri, 26 Jul 2024 22:40:56 +0200 Subject: [PATCH] Lots of refactoring. Now hijacking the ssh connection setup in the listener to exchange some information before passing the connection on to the SSH server. Next step is to do the full exchange of required information and to make it easy some simple Read and Write methods with timeouts are needed that use gob. --- cmd/agent/agent.go | 5 ++- pkg/agent/session.go | 8 ++-- pkg/comms/agentserver.go | 38 +++++------------ pkg/comms/binary.go | 92 ++++++++++++++++++++++++++++++++++++++++ pkg/comms/tcpchannel.go | 33 ++++++++++++++ pkg/converge/admin.go | 6 ++- 6 files changed, 147 insertions(+), 35 deletions(-) create mode 100644 pkg/comms/binary.go create mode 100644 pkg/comms/tcpchannel.go diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 3cf05e0..524ce88 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -291,7 +291,10 @@ func main() { log.Println() agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval) - service.Run(commChannel.Session) + + listener := comms.NewAgentListener(commChannel.Session) + + service.Run(listener) } func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 3bf0136..f195308 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -101,8 +101,8 @@ func ConfigureAgent(commChannel comms.CommChannel, log.Printf("Agent expires at %s", state.expiryTime(holdFilename).Format(time.DateTime)) - comms.Send(state.commChannel, comms.NewAgentInfo()) - comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + state.commChannel.SideChannel.Send(comms.NewAgentInfo()) + state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) go func() { for { @@ -182,7 +182,7 @@ func login(sessionId int, sshSession ssh.Session) { if sessionType == "" { sessionType = "ssh" } - comms.Send(state.commChannel, comms.NewSessionInfo(sessionType)) + state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType)) holdFileStats, ok := fileExistsWithStats(holdFilename) if ok { @@ -297,7 +297,7 @@ func holdFileChange() { message += holdFileMessage() messageUsers(message) state.lastExpiryTimmeReported = newExpiryTIme - comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) + state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) } } diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 69fa18b..ba7fef2 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -6,21 +6,13 @@ import ( "github.com/hashicorp/yamux" "io" "log" - "net" "time" ) type CommChannel struct { - Peer net.Conn - Encoder *gob.Encoder - Decoder *gob.Decoder - Session *yamux.Session -} - -type AgentListener interface { - AgentInfo(agent AgentInfo) - SessionInfo(session SessionInfo) - ExpiryTimeUpdate(session ExpiryTimeUpdate) + // a separet connection outside of the ssh session + SideChannel TCPChannel + Session *yamux.Session } type Role int @@ -39,7 +31,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { return CommChannel{}, err } commChannel = CommChannel{ - Peer: nil, Session: listener, } case ConvergeServer: @@ -48,7 +39,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { return CommChannel{}, err } commChannel = CommChannel{ - Peer: nil, Session: clientSession, } default: @@ -61,13 +51,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { switch role { case Agent: conn, err := commChannel.Session.OpenStream() - commChannel.Peer = conn + commChannel.SideChannel.Peer = conn if err != nil { return CommChannel{}, err } case ConvergeServer: conn, err := commChannel.Session.Accept() - commChannel.Peer = conn + commChannel.SideChannel.Peer = conn if err != nil { return CommChannel{}, err } @@ -77,15 +67,15 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { log.Println("Communication channel between agent and converge server established") RegisterEventsWithGob() - commChannel.Encoder = gob.NewEncoder(commChannel.Peer) - commChannel.Decoder = gob.NewDecoder(commChannel.Peer) + commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer) + commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer) // heartbeat if role == Agent { go func() { for { time.Sleep(10 * time.Second) - err := Send(commChannel, HeartBeat{}) + err := commChannel.SideChannel.Send(HeartBeat{}) if err != nil { log.Println("Sending heartbeat to server failed") } @@ -98,15 +88,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { // Sending an event to the other side -func Send(commChannel CommChannel, object any) error { - err := commChannel.Encoder.Encode(ConvergeMessage{Value: object}) - if err != nil { - log.Printf("Encoding error %v", err) - } - return err -} - -func ListenForAgentEvents(channel CommChannel, +func ListenForAgentEvents(channel TCPChannel, agentInfo func(agent AgentInfo), sessionInfo func(session SessionInfo), expiryTimeUpdate func(session ExpiryTimeUpdate)) { @@ -145,7 +127,7 @@ func ListenForServerEvents(channel CommChannel, setUsernamePassword func(user UserPassword)) { for { var result ConvergeMessage - err := channel.Decoder.Decode(&result) + err := channel.SideChannel.Decoder.Decode(&result) if err != nil { // TODO more clean solution, need to explicitly close when agent exits. diff --git a/pkg/comms/binary.go b/pkg/comms/binary.go new file mode 100644 index 0000000..13e48d3 --- /dev/null +++ b/pkg/comms/binary.go @@ -0,0 +1,92 @@ +package comms + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "time" +) + +const protocol_version = 1 + +// this file contains utilities for the binary protocol as well as a listener that wraps an existing +// listener and does an exchange with the client as part of accepting the connection before +// passing it on the application. This is used to pass metadata from converge server to the ssh agent +// so that messages from agent to converge server can be correlated with the client ssh session. + +func SendInt(w io.Writer, val int) error { + val32 := int32(val) + return binary.Write(w, binary.BigEndian, val32) +} + +func ReceiveInt(r io.Reader) (int, error) { + var val int32 + err := binary.Read(r, binary.BigEndian, &val) + return int(val), err +} + +type AgentListener struct { + decorated net.Listener +} + +func NewAgentListener(listener net.Listener) AgentListener { + return AgentListener{decorated: listener} +} + +func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) { + errors := make(chan error) + values := make(chan int) + + go func() { + err := SendInt(conn, version) + if err != nil { + errors <- err + } + }() + go func() { + val, err := ReceiveInt(conn) + if err != nil { + errors <- err + } else { + values <- val + } + + }() + + select { + case err := <-errors: + log.Printf("Error exchanging protocol version %v", err) + return 0, err + case <-time.After(10 * time.Second): + log.Println("Timeout exchanging protocol version") + return 0, fmt.Errorf("Timeout echangeing protocol version with converge server") + case val := <-values: + log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val) + return val, nil + } +} + +func (listener AgentListener) Accept() (net.Conn, error) { + conn, err := listener.decorated.Accept() + if err != nil { + return nil, err + } + + _, err = ExchangeProtocolVersion(99, conn) + if err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func (listener AgentListener) Close() error { + return listener.decorated.Close() +} + +func (listener AgentListener) Addr() net.Addr { + return listener.decorated.Addr() +} diff --git a/pkg/comms/tcpchannel.go b/pkg/comms/tcpchannel.go new file mode 100644 index 0000000..9eb5ade --- /dev/null +++ b/pkg/comms/tcpchannel.go @@ -0,0 +1,33 @@ +package comms + +import ( + "encoding/gob" + "log" + "net" + "time" +) + +type TCPChannel struct { + // can be any connection, including the ssh connnection before it is + // passed on to SSH during initialization of converge to agent communication + Peer net.Conn + Encoder *gob.Encoder + Decoder *gob.Decoder +} + +// Synchronous functions with timeouts and error handling. +func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error { + return nil +} + +func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error { + return nil +} + +func (channel TCPChannel) Send(object any) error { + err := channel.Encoder.Encode(ConvergeMessage{Value: object}) + if err != nil { + log.Printf("Encoding error %v", err) + } + return err +} diff --git a/pkg/converge/admin.go b/pkg/converge/admin.go index 29b6790..ad302d8 100644 --- a/pkg/converge/admin.go +++ b/pkg/converge/admin.go @@ -181,10 +181,10 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, }() log.Println("Sending username and password to agent") - comms.Send(agent.commChannel, userPassword) + agent.commChannel.SideChannel.Send(userPassword) go func() { - comms.ListenForAgentEvents(agent.commChannel, + comms.ListenForAgentEvents(agent.commChannel.SideChannel, func(info comms.AgentInfo) { agent.agentInfo = info admin.logStatus() @@ -223,6 +223,8 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser }() log.Printf("Connecting client and agent: '%s'\n", publicId) + comms.ExchangeProtocolVersion(1111, client.agent) + iowrappers.SynchronizeStreams(client.client, client.agent) return nil }