GOB channel for easily and asynchronously using GOB on a single network connection, also dealing with timeouts and errors in a good way.
Protocol version is now checked when the agent connects to the converge server. Next up: sending connection metadata and username password from server to agent and sending environment information back to the server. This means then that the side channel will only be used for expiry time messages and session type with the client id passed in so the converge server can than correlate the results back to the correct channel.
This commit is contained in:
parent
f82601d07c
commit
621bbd8ca6
@ -249,6 +249,11 @@ func main() {
|
||||
wsConn := websocketutil.NewWebSocketConn(conn)
|
||||
defer wsConn.Close()
|
||||
|
||||
err = comms.CheckProtocolVersion(comms.Agent, wsConn)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
commChannel, err := comms.NewCommChannel(comms.Agent, wsConn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -51,7 +51,7 @@ func printHelp(msg string) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
downloadDir := "downloads"
|
||||
downloadDir := "../static"
|
||||
|
||||
args := os.Args[1:]
|
||||
for len(args) > 0 && strings.HasPrefix(args[0], "-") {
|
||||
|
36
pkg/comms/agentlistener.go
Normal file
36
pkg/comms/agentlistener.go
Normal file
@ -0,0 +1,36 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type AgentListener struct {
|
||||
decorated net.Listener
|
||||
}
|
||||
|
||||
func NewAgentListener(listener net.Listener) AgentListener {
|
||||
return AgentListener{decorated: listener}
|
||||
}
|
||||
|
||||
func (listener AgentListener) Accept() (net.Conn, error) {
|
||||
conn, err := listener.decorated.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//_, err = CheckProtocolVersion(Agent, conn)
|
||||
//if err != nil {
|
||||
// conn.Close()
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (listener AgentListener) Close() error {
|
||||
return listener.decorated.Close()
|
||||
}
|
||||
|
||||
func (listener AgentListener) Addr() net.Addr {
|
||||
return listener.decorated.Addr()
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"github.com/hashicorp/yamux"
|
||||
"io"
|
||||
@ -11,7 +10,7 @@ import (
|
||||
|
||||
type CommChannel struct {
|
||||
// a separet connection outside of the ssh session
|
||||
SideChannel TCPChannel
|
||||
SideChannel GOBChannel
|
||||
Session *yamux.Session
|
||||
}
|
||||
|
||||
@ -51,13 +50,13 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
||||
switch role {
|
||||
case Agent:
|
||||
conn, err := commChannel.Session.OpenStream()
|
||||
commChannel.SideChannel.Peer = conn
|
||||
commChannel.SideChannel = NewGOBChannel(conn)
|
||||
if err != nil {
|
||||
return CommChannel{}, err
|
||||
}
|
||||
case ConvergeServer:
|
||||
conn, err := commChannel.Session.Accept()
|
||||
commChannel.SideChannel.Peer = conn
|
||||
commChannel.SideChannel = NewGOBChannel(conn)
|
||||
if err != nil {
|
||||
return CommChannel{}, err
|
||||
}
|
||||
@ -66,10 +65,6 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
||||
}
|
||||
log.Println("Communication channel between agent and converge server established")
|
||||
|
||||
RegisterEventsWithGob()
|
||||
commChannel.SideChannel.Encoder = gob.NewEncoder(commChannel.SideChannel.Peer)
|
||||
commChannel.SideChannel.Decoder = gob.NewDecoder(commChannel.SideChannel.Peer)
|
||||
|
||||
// heartbeat
|
||||
if role == Agent {
|
||||
go func() {
|
||||
@ -88,7 +83,7 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) {
|
||||
|
||||
// Sending an event to the other side
|
||||
|
||||
func ListenForAgentEvents(channel TCPChannel,
|
||||
func ListenForAgentEvents(channel GOBChannel,
|
||||
agentInfo func(agent AgentInfo),
|
||||
sessionInfo func(session SessionInfo),
|
||||
expiryTimeUpdate func(session ExpiryTimeUpdate)) {
|
||||
@ -144,3 +139,38 @@ func ListenForServerEvents(channel CommChannel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CheckProtocolVersion(role Role, conn io.ReadWriter) error {
|
||||
channel := NewGOBChannel(conn)
|
||||
|
||||
sends := make(chan any)
|
||||
receives := make(chan any)
|
||||
errors := make(chan error)
|
||||
|
||||
channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors)
|
||||
channel.ReceiveAsync(receives, errors)
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
log.Println("PROTOCOLVERSION: timeout")
|
||||
return fmt.Errorf("Timeout waiting for protocol version")
|
||||
case err := <-errors:
|
||||
log.Printf("PROTOCOLVERSION: %v", err)
|
||||
return err
|
||||
case protocolVersion := <-receives:
|
||||
otherVersion := protocolVersion.(ProtocolVersion).Version
|
||||
if PROTOCOL_VERSION != otherVersion {
|
||||
switch role {
|
||||
case Agent:
|
||||
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
||||
PROTOCOL_VERSION, otherVersion)
|
||||
case ConvergeServer:
|
||||
log.Printf("Protocol version mismatch: agent %d, converge server %d",
|
||||
otherVersion, PROTOCOL_VERSION)
|
||||
}
|
||||
return fmt.Errorf("Protocol version mismatch")
|
||||
}
|
||||
log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -1,92 +0,0 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const protocol_version = 1
|
||||
|
||||
// this file contains utilities for the binary protocol as well as a listener that wraps an existing
|
||||
// listener and does an exchange with the client as part of accepting the connection before
|
||||
// passing it on the application. This is used to pass metadata from converge server to the ssh agent
|
||||
// so that messages from agent to converge server can be correlated with the client ssh session.
|
||||
|
||||
func SendInt(w io.Writer, val int) error {
|
||||
val32 := int32(val)
|
||||
return binary.Write(w, binary.BigEndian, val32)
|
||||
}
|
||||
|
||||
func ReceiveInt(r io.Reader) (int, error) {
|
||||
var val int32
|
||||
err := binary.Read(r, binary.BigEndian, &val)
|
||||
return int(val), err
|
||||
}
|
||||
|
||||
type AgentListener struct {
|
||||
decorated net.Listener
|
||||
}
|
||||
|
||||
func NewAgentListener(listener net.Listener) AgentListener {
|
||||
return AgentListener{decorated: listener}
|
||||
}
|
||||
|
||||
func ExchangeProtocolVersion(version int, conn io.ReadWriter) (int, error) {
|
||||
errors := make(chan error)
|
||||
values := make(chan int)
|
||||
|
||||
go func() {
|
||||
err := SendInt(conn, version)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
val, err := ReceiveInt(conn)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else {
|
||||
values <- val
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errors:
|
||||
log.Printf("Error exchanging protocol version %v", err)
|
||||
return 0, err
|
||||
case <-time.After(10 * time.Second):
|
||||
log.Println("Timeout exchanging protocol version")
|
||||
return 0, fmt.Errorf("Timeout echangeing protocol version with converge server")
|
||||
case val := <-values:
|
||||
log.Printf("ExchangeProtocolVersion: DEBUG: Got value %v", val)
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (listener AgentListener) Accept() (net.Conn, error) {
|
||||
conn, err := listener.decorated.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = ExchangeProtocolVersion(99, conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (listener AgentListener) Close() error {
|
||||
return listener.decorated.Close()
|
||||
}
|
||||
|
||||
func (listener AgentListener) Addr() net.Addr {
|
||||
return listener.decorated.Addr()
|
||||
}
|
@ -8,6 +8,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const PROTOCOL_VERSION = 1
|
||||
|
||||
func init() {
|
||||
RegisterEventsWithGob()
|
||||
}
|
||||
|
||||
// Client to server events
|
||||
|
||||
type AgentInfo struct {
|
||||
@ -32,11 +38,20 @@ type HeartBeat struct {
|
||||
|
||||
// Message sent from converge server to agent
|
||||
|
||||
type ProtocolVersion struct {
|
||||
Version int
|
||||
}
|
||||
|
||||
type UserPassword struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type ConnectionInfo struct {
|
||||
ConnectionId int
|
||||
UserPassword UserPassword
|
||||
}
|
||||
|
||||
// Generic wrapper message required to send messages of arbitrary type
|
||||
|
||||
type ConvergeMessage struct {
|
||||
@ -71,7 +86,9 @@ func RegisterEventsWithGob() {
|
||||
gob.Register(HeartBeat{})
|
||||
|
||||
// ConvergeServer to Agent
|
||||
gob.Register(ProtocolVersion{})
|
||||
gob.Register(UserPassword{})
|
||||
gob.Register(ConnectionInfo{})
|
||||
|
||||
// Wrapper event.
|
||||
gob.Register(ConvergeMessage{})
|
||||
|
62
pkg/comms/gobchannel.go
Normal file
62
pkg/comms/gobchannel.go
Normal file
@ -0,0 +1,62 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
type GOBChannel struct {
|
||||
// can be any connection, including the ssh connnection before it is
|
||||
// passed on to SSH during initialization of converge to agent communication
|
||||
Peer io.ReadWriter
|
||||
Encoder *gob.Encoder
|
||||
Decoder *gob.Decoder
|
||||
}
|
||||
|
||||
func NewGOBChannel(conn io.ReadWriter) GOBChannel {
|
||||
return GOBChannel{
|
||||
Peer: conn,
|
||||
Encoder: gob.NewEncoder(conn),
|
||||
Decoder: gob.NewDecoder(conn),
|
||||
}
|
||||
}
|
||||
|
||||
// Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
|
||||
// messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
|
||||
|
||||
func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) {
|
||||
go func() {
|
||||
err := channel.Send(obj)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else {
|
||||
done <- true
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) {
|
||||
go func() {
|
||||
value, err := channel.Receive()
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else {
|
||||
result <- value
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (channel GOBChannel) Send(object any) error {
|
||||
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
|
||||
if err != nil {
|
||||
log.Printf("Encoding error %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel GOBChannel) Receive() (any, error) {
|
||||
var target ConvergeMessage
|
||||
err := channel.Decoder.Decode(&target)
|
||||
return target.Value, err
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
package comms
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TCPChannel struct {
|
||||
// can be any connection, including the ssh connnection before it is
|
||||
// passed on to SSH during initialization of converge to agent communication
|
||||
Peer net.Conn
|
||||
Encoder *gob.Encoder
|
||||
Decoder *gob.Decoder
|
||||
}
|
||||
|
||||
// Synchronous functions with timeouts and error handling.
|
||||
func (channel TCPChannel) SendAsync(object any, timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel TCPChannel) ReceiveAsync(object any, timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel TCPChannel) Send(object any) error {
|
||||
err := channel.Encoder.Encode(ConvergeMessage{Value: object})
|
||||
if err != nil {
|
||||
log.Printf("Encoding error %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
@ -171,6 +171,12 @@ func (admin *Admin) RemoveClient(client *Client) error {
|
||||
func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser,
|
||||
userPassword comms.UserPassword) error {
|
||||
defer conn.Close()
|
||||
|
||||
err := comms.CheckProtocolVersion(comms.ConvergeServer, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove agent return value
|
||||
agent, err := admin.addAgent(publicId, conn)
|
||||
if err != nil {
|
||||
@ -223,8 +229,6 @@ func (admin *Admin) Connect(publicId string, conn iowrappers.ReadWriteAddrCloser
|
||||
}()
|
||||
log.Printf("Connecting client and agent: '%s'\n", publicId)
|
||||
|
||||
comms.ExchangeProtocolVersion(1111, client.agent)
|
||||
|
||||
iowrappers.SynchronizeStreams(client.client, client.agent)
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user