Lots of refactoring.
Now hijacking the ssh connection setup in the listener to exchange some information before passing the connection on to the SSH server. Next step is to do the full exchange of required information and to make it easy some simple Read and Write methods with timeouts are needed that use gob.
This commit is contained in:
parent
4d660a6805
commit
d3cbf8388f
@ -291,7 +291,10 @@ func main() {
|
|||||||
log.Println()
|
log.Println()
|
||||||
|
|
||||||
agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval)
|
agent.ConfigureAgent(commChannel, advanceWarningTime, agentExpriryTime, tickerInterval)
|
||||||
service.Run(commChannel.Session)
|
|
||||||
|
listener := comms.NewAgentListener(commChannel.Session)
|
||||||
|
|
||||||
|
service.Run(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
|
func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) {
|
||||||
|
@ -101,8 +101,8 @@ func ConfigureAgent(commChannel comms.CommChannel,
|
|||||||
log.Printf("Agent expires at %s",
|
log.Printf("Agent expires at %s",
|
||||||
state.expiryTime(holdFilename).Format(time.DateTime))
|
state.expiryTime(holdFilename).Format(time.DateTime))
|
||||||
|
|
||||||
comms.Send(state.commChannel, comms.NewAgentInfo())
|
state.commChannel.SideChannel.Send(comms.NewAgentInfo())
|
||||||
comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@ -182,7 +182,7 @@ func login(sessionId int, sshSession ssh.Session) {
|
|||||||
if sessionType == "" {
|
if sessionType == "" {
|
||||||
sessionType = "ssh"
|
sessionType = "ssh"
|
||||||
}
|
}
|
||||||
comms.Send(state.commChannel, comms.NewSessionInfo(sessionType))
|
state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType))
|
||||||
|
|
||||||
holdFileStats, ok := fileExistsWithStats(holdFilename)
|
holdFileStats, ok := fileExistsWithStats(holdFilename)
|
||||||
if ok {
|
if ok {
|
||||||
@ -297,7 +297,7 @@ func holdFileChange() {
|
|||||||
message += holdFileMessage()
|
message += holdFileMessage()
|
||||||
messageUsers(message)
|
messageUsers(message)
|
||||||
state.lastExpiryTimmeReported = newExpiryTIme
|
state.lastExpiryTimmeReported = newExpiryTIme
|
||||||
comms.Send(state.commChannel, comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,21 +6,13 @@ import (
|
|||||||
"github.com/hashicorp/yamux"
|
"github.com/hashicorp/yamux"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CommChannel struct {
|
type CommChannel struct {
|
||||||
Peer net.Conn
|
// a separet connection outside of the ssh session
|
||||||
Encoder *gob.Encoder
|
SideChannel TCPChannel
|
||||||
Decoder *gob.Decoder
|
Session *yamux.Session
|
||||||
Session *yamux.Session
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentListener interface {
|
|
||||||
AgentInfo(agent AgentInfo)
|
|
||||||
SessionInfo(session SessionInfo)
|
|
||||||
ExpiryTimeUpdate(session ExpiryTimeUpdate)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Role int
|
type Role int
|
||||||
@ -39,7 +31,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
return CommChannel{}, err
|
return CommChannel{}, err
|
||||||
}
|
}
|
||||||
commChannel = CommChannel{
|
commChannel = CommChannel{
|
||||||
Peer: nil,
|
|
||||||
Session: listener,
|
Session: listener,
|
||||||
}
|
}
|
||||||
case ConvergeServer:
|
case ConvergeServer:
|
||||||
@ -48,7 +39,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
return CommChannel{}, err
|
return CommChannel{}, err
|
||||||
}
|
}
|
||||||
commChannel = CommChannel{
|
commChannel = CommChannel{
|
||||||
Peer: nil,
|
|
||||||
Session: clientSession,
|
Session: clientSession,
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -61,13 +51,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
switch role {
|
switch role {
|
||||||
case Agent:
|
case Agent:
|
||||||
conn, err := commChannel.Session.OpenStream()
|
conn, err := commChannel.Session.OpenStream()
|
||||||
commChannel.Peer = conn
|
commChannel.SideChannel.Peer = conn
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CommChannel{}, err
|
return CommChannel{}, err
|
||||||
}
|
}
|
||||||
case ConvergeServer:
|
case ConvergeServer:
|
||||||
conn, err := commChannel.Session.Accept()
|
conn, err := commChannel.Session.Accept()
|
||||||
commChannel.Peer = conn
|
commChannel.SideChannel.Peer = conn
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CommChannel{}, err
|
return CommChannel{}, err
|
||||||
}
|
}
|
||||||
@ -77,15 +67,15 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
log.Println("Communication channel between agent and converge server established")
|
log.Println("Communication channel between agent and converge server established")
|
||||||
|
|
||||||
RegisterEventsWithGob()
|
RegisterEventsWithGob()
|
||||||
commChannel.Encoder = gob.NewEncoder(commChannel.Peer)
|
commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer)
|
||||||
commChannel.Decoder = gob.NewDecoder(commChannel.Peer)
|
commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer)
|
||||||
|
|
||||||
// heartbeat
|
// heartbeat
|
||||||
if role == Agent {
|
if role == Agent {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
err := Send(commChannel, HeartBeat{})
|
err := commChannel.SideChannel.Send(HeartBeat{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Sending heartbeat to server failed")
|
log.Println("Sending heartbeat to server failed")
|
||||||
}
|
}
|
||||||
@ -98,15 +88,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
|||||||
|
|
||||||
// Sending an event to the other side
|
// Sending an event to the other side
|
||||||
|
|
||||||
func Send(commChannel CommChannel, object any) error {
|
func ListenForAgentEvents(channel TCPChannel,
|
||||||
err := commChannel.Encoder.Encode(ConvergeMessage{Value: object})
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Encoding error %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListenForAgentEvents(channel CommChannel,
|
|
||||||
agentInfo func(agent AgentInfo),
|
agentInfo func(agent AgentInfo),
|
||||||
sessionInfo func(session SessionInfo),
|
sessionInfo func(session SessionInfo),
|
||||||
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
||||||
@ -145,7 +127,7 @@ func ListenForServerEvents(channel CommChannel,
|
|||||||
setUsernamePassword func(user UserPassword)) {
|
setUsernamePassword func(user UserPassword)) {
|
||||||
for {
|
for {
|
||||||
var result ConvergeMessage
|
var result ConvergeMessage
|
||||||
err := channel.Decoder.Decode(&result)
|
err := channel.SideChannel.Decoder.Decode(&result)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO more clean solution, need to explicitly close when agent exits.
|
// TODO more clean solution, need to explicitly close when agent exits.
|
||||||
|
92
pkg/comms/binary.go
Normal file
92
pkg/comms/binary.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package comms
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const protocol_version = 1
|
||||||
|
|
||||||
|
// this file contains utilities for the binary protocol as well as a listener that wraps an existing
|
||||||
|
// listener and does an exchange with the client as part of accepting the connection before
|
||||||
|
// passing it on the application. This is used to pass metadata from converge server to the ssh agent
|
||||||
|
// so that messages from agent to converge server can be correlated with the client ssh session.
|
||||||
|
|
||||||
|
func SendInt(w io.Writer, val int) error {
|
||||||
|
val32 := int32(val)
|
||||||
|
return binary.Write(w, binary.BigEndian, val32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReceiveInt(r io.Reader) (int, error) {
|
||||||
|
var val int32
|
||||||
|
err := binary.Read(r, binary.BigEndian, &val)
|
||||||
|
return int(val), err
|
||||||
|
}
|
||||||
|
|
||||||
|
type AgentListener struct {
|
||||||
|
decorated net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAgentListener(listener net.Listener) AgentListener {
|
||||||
|
return AgentListener{decorated: listener}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
|
||||||
|
errors := make(chan error)
|
||||||
|
values := make(chan int)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := SendInt(conn, version)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
val, err := ReceiveInt(conn)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
} else {
|
||||||
|
values <- val
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errors:
|
||||||
|
log.Printf("Error exchanging protocol version %v", err)
|
||||||
|
return 0, err
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
log.Println("Timeout exchanging protocol version")
|
||||||
|
return 0, fmt.Errorf("Timeout echangeing protocol version with converge server")
|
||||||
|
case val := <-values:
|
||||||
|
log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val)
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (listener AgentListener) Accept() (net.Conn, error) {
|
||||||
|
conn, err := listener.decorated.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ExchangeProtocolVersion(99, conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (listener AgentListener) Close() error {
|
||||||
|
return listener.decorated.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (listener AgentListener) Addr() net.Addr {
|
||||||
|
return listener.decorated.Addr()
|
||||||
|
}
|
33
pkg/comms/tcpchannel.go
Normal file
33
pkg/comms/tcpchannel.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package comms
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/gob"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPChannel struct {
|
||||||
|
// can be any connection, including the ssh connnection before it is
|
||||||
|
// passed on to SSH during initialization of converge to agent communication
|
||||||
|
Peer net.Conn
|
||||||
|
Encoder *gob.Encoder
|
||||||
|
Decoder *gob.Decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Synchronous functions with timeouts and error handling.
|
||||||
|
func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel TCPChannel) Send(object any) error {
|
||||||
|
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Encoding error %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
@ -181,10 +181,10 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Println("Sending username and password to agent")
|
log.Println("Sending username and password to agent")
|
||||||
comms.Send(agent.commChannel, userPassword)
|
agent.commChannel.SideChannel.Send(userPassword)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
comms.ListenForAgentEvents(agent.commChannel,
|
comms.ListenForAgentEvents(agent.commChannel.SideChannel,
|
||||||
func(info comms.AgentInfo) {
|
func(info comms.AgentInfo) {
|
||||||
agent.agentInfo = info
|
agent.agentInfo = info
|
||||||
admin.logStatus()
|
admin.logStatus()
|
||||||
@ -223,6 +223,8 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser
|
|||||||
}()
|
}()
|
||||||
log.Printf("Connecting client and agent: '%s'\n", publicId)
|
log.Printf("Connecting client and agent: '%s'\n", publicId)
|
||||||
|
|
||||||
|
comms.ExchangeProtocolVersion(1111, client.agent)
|
||||||
|
|
||||||
iowrappers.SynchronizeStreams(client.client, client.agent)
|
iowrappers.SynchronizeStreams(client.client, client.agent)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user