Lots of refactoring.

Now hijacking the ssh connection setup in the listener to exchange some information before passing the connection on to the SSH server.

Next step is to do the full exchange of required information and to make it easy some simple Read and Write methods with timeouts are needed that use gob.
This commit is contained in:
Erik Brakkee 2024-07-26 22:40:56 +02:00
parent 4d660a6805
commit d3cbf8388f
6 changed files with 147 additions and 35 deletions

View File

@ -291,7 +291,10 @@ func main() {
log.Println()
agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval)
service.Run(commChannel.Session)
listener := comms.NewAgentListener(commChannel.Session)
service.Run(listener)
}
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {

View File

@ -101,8 +101,8 @@ func ConfigureAgent(commChannel comms.CommChannel,
log.Printf("Agent expires at %s",
state.expiryTime(holdFilename).Format(time.DateTime))
comms.Send(state.commChannel, comms.NewAgentInfo())
comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
state.commChannel.SideChannel.Send(comms.NewAgentInfo())
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
go func() {
for {
@ -182,7 +182,7 @@ func login(sessionId int, sshSession ssh.Session) {
if sessionType == "" {
sessionType = "ssh"
}
comms.Send(state.commChannel, comms.NewSessionInfo(sessionType))
state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType))
holdFileStats, ok := fileExistsWithStats(holdFilename)
if ok {
@ -297,7 +297,7 @@ func holdFileChange() {
message += holdFileMessage()
messageUsers(message)
state.lastExpiryTimmeReported = newExpiryTIme
comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
}
}

View File

@ -6,23 +6,15 @@ import (
"github.com/hashicorp/yamux"
"io"
"log"
"net"
"time"
)
type CommChannel struct {
Peer net.Conn
Encoder *gob.Encoder
Decoder *gob.Decoder
// a separet connection outside of the ssh session
SideChannel TCPChannel
Session *yamux.Session
}
type AgentListener interface {
AgentInfo(agent AgentInfo)
SessionInfo(session SessionInfo)
ExpiryTimeUpdate(session ExpiryTimeUpdate)
}
type Role int
const (
@ -39,7 +31,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
return CommChannel{}, err
}
commChannel = CommChannel{
Peer: nil,
Session: listener,
}
case ConvergeServer:
@ -48,7 +39,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
return CommChannel{}, err
}
commChannel = CommChannel{
Peer: nil,
Session: clientSession,
}
default:
@ -61,13 +51,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
switch role {
case Agent:
conn, err := commChannel.Session.OpenStream()
commChannel.Peer = conn
commChannel.SideChannel.Peer = conn
if err != nil {
return CommChannel{}, err
}
case ConvergeServer:
conn, err := commChannel.Session.Accept()
commChannel.Peer = conn
commChannel.SideChannel.Peer = conn
if err != nil {
return CommChannel{}, err
}
@ -77,15 +67,15 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
log.Println("Communication channel between agent and converge server established")
RegisterEventsWithGob()
commChannel.Encoder = gob.NewEncoder(commChannel.Peer)
commChannel.Decoder = gob.NewDecoder(commChannel.Peer)
commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer)
commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer)
// heartbeat
if role == Agent {
go func() {
for {
time.Sleep(10 * time.Second)
err := Send(commChannel, HeartBeat{})
err := commChannel.SideChannel.Send(HeartBeat{})
if err != nil {
log.Println("Sending heartbeat to server failed")
}
@ -98,15 +88,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
// Sending an event to the other side
func Send(commChannel CommChannel, object any) error {
err := commChannel.Encoder.Encode(ConvergeMessage{Value: object})
if err != nil {
log.Printf("Encoding error %v", err)
}
return err
}
func ListenForAgentEvents(channel CommChannel,
func ListenForAgentEvents(channel TCPChannel,
agentInfo func(agent AgentInfo),
sessionInfo func(session SessionInfo),
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
@ -145,7 +127,7 @@ func ListenForServerEvents(channel CommChannel,
setUsernamePassword func(user UserPassword)) {
for {
var result ConvergeMessage
err := channel.Decoder.Decode(&result)
err := channel.SideChannel.Decoder.Decode(&result)
if err != nil {
// TODO more clean solution, need to explicitly close when agent exits.

92
pkg/comms/binary.go Normal file
View File

@ -0,0 +1,92 @@
package comms
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"time"
)
const protocol_version = 1
// this file contains utilities for the binary protocol as well as a listener that wraps an existing
// listener and does an exchange with the client as part of accepting the connection before
// passing it on the application. This is used to pass metadata from converge server to the ssh agent
// so that messages from agent to converge server can be correlated with the client ssh session.
func SendInt(w io.Writer, val int) error {
val32 := int32(val)
return binary.Write(w, binary.BigEndian, val32)
}
func ReceiveInt(r io.Reader) (int, error) {
var val int32
err := binary.Read(r, binary.BigEndian, &val)
return int(val), err
}
type AgentListener struct {
decorated net.Listener
}
func NewAgentListener(listener net.Listener) AgentListener {
return AgentListener{decorated: listener}
}
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
errors := make(chan error)
values := make(chan int)
go func() {
err := SendInt(conn, version)
if err != nil {
errors <- err
}
}()
go func() {
val, err := ReceiveInt(conn)
if err != nil {
errors <- err
} else {
values <- val
}
}()
select {
case err := <-errors:
log.Printf("Error exchanging protocol version %v", err)
return 0, err
case <-time.After(10 * time.Second):
log.Println("Timeout exchanging protocol version")
return 0, fmt.Errorf("Timeout echangeing protocol version with converge server")
case val := <-values:
log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val)
return val, nil
}
}
func (listener AgentListener) Accept() (net.Conn, error) {
conn, err := listener.decorated.Accept()
if err != nil {
return nil, err
}
_, err = ExchangeProtocolVersion(99, conn)
if err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func (listener AgentListener) Close() error {
return listener.decorated.Close()
}
func (listener AgentListener) Addr() net.Addr {
return listener.decorated.Addr()
}

33
pkg/comms/tcpchannel.go Normal file
View File

@ -0,0 +1,33 @@
package comms
import (
"encoding/gob"
"log"
"net"
"time"
)
type TCPChannel struct {
// can be any connection, including the ssh connnection before it is
// passed on to SSH during initialization of converge to agent communication
Peer net.Conn
Encoder *gob.Encoder
Decoder *gob.Decoder
}
// Synchronous functions with timeouts and error handling.
func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error {
return nil
}
func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error {
return nil
}
func (channel TCPChannel) Send(object any) error {
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
if err != nil {
log.Printf("Encoding error %v", err)
}
return err
}

View File

@ -181,10 +181,10 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
}()
log.Println("Sending username and password to agent")
comms.Send(agent.commChannel, userPassword)
agent.commChannel.SideChannel.Send(userPassword)
go func() {
comms.ListenForAgentEvents(agent.commChannel,
comms.ListenForAgentEvents(agent.commChannel.SideChannel,
func(info comms.AgentInfo) {
agent.agentInfo = info
admin.logStatus()
@ -223,6 +223,8 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser
}()
log.Printf("Connecting client and agent: '%s'\n", publicId)
comms.ExchangeProtocolVersion(1111, client.agent)
iowrappers.SynchronizeStreams(client.client, client.agent)
return nil
}