From ff97c1ccd202ebdaff0cf4b5dcf97813315c7d03 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Thu, 1 Aug 2024 20:22:41 +0200 Subject: [PATCH] Now by default wsproxy uses a specific protocol to establish connections to the server. It does this by adding the ?wsproxy query parameter. The server then sends it the protocol version and the client connection info describing whether an agent was found or not. This improves usability for users. With the --raw option it bypasses this query parameter and wsproxy then works in the old way as a simple stdio-websocket connector. It then still works with converge server but can also be used for simple websocket troubleshooting. --- cmd/converge/converge.go | 3 +- cmd/wsproxy/wsproxy.go | 91 +++++++++++++++++++++++++++++++----- pkg/comms/events.go | 18 +++++-- pkg/server/converge/admin.go | 36 +++++++++++++- 4 files changed, 130 insertions(+), 18 deletions(-) diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go index 39370d8..bae64aa 100644 --- a/cmd/converge/converge.go +++ b/cmd/converge/converge.go @@ -127,8 +127,9 @@ func main() { log.Printf("Cannot parse public id from url: '%v'\n", err) return } + _, wsProxyMode := r.URL.Query()["wsproxy"] log.Printf("Got client connection: '%s'\n", publicId) - err = admin.Connect(publicId, conn) + err = admin.Connect(wsProxyMode, publicId, conn) if err != nil { log.Printf("Error %v\n", err) } diff --git a/cmd/wsproxy/wsproxy.go b/cmd/wsproxy/wsproxy.go index 8ad123d..c38d894 100644 --- a/cmd/wsproxy/wsproxy.go +++ b/cmd/wsproxy/wsproxy.go @@ -1,6 +1,7 @@ package main import ( + "converge/pkg/comms" "converge/pkg/support/iowrappers" "converge/pkg/support/websocketutil" "crypto/tls" @@ -9,7 +10,9 @@ import ( "log" "net" "net/http" + "net/url" "os" + "strings" "time" ) @@ -29,7 +32,17 @@ func (stdio Stdio) Write(b []byte) (n int, err error) { return os.Stdout.Write(b) } -func main() { +func getArg(args []string) (value string, ret []string) { + if len(args) < 2 { + printHelp(fmt.Sprintf("The '%s' option expects an argument", args[0])) + } + return args[1], args[1:] +} + +func printHelp(msg string) { + if msg != "" { + fmt.Fprintf(os.Stderr, "ERROR: %s\n\n", msg) + } usage := "Usage: wsproxy [--insecure] ws[s]://[:port]/client/\n\n" + "\n" + "Here is the rendez-vous id of a continuous integration job\n" + @@ -38,22 +51,50 @@ func main() { "\n" + " ssh -oServerAliveInterval=10 -oProxyCommand='wsproxy ws[s]://[:port]/client/' abc@localhost\n" + "\n" + - "This latssh connect through wsproxy tocalhost\n" + "This latssh connect through wsproxy tocalhost\n" + + "\n" + + "Options:\n" + + "\n" + + "--insecure: Skip certificate validation when used over a secure connection.\n" + + "--raw: Just use wsproxy as a raw stdio to websocket proxy, this disable messages\n" + + " from the server.\n" + fmt.Fprintln(os.Stderr, usage) + os.Exit(1) +} + +func main() { + + insecure := false + raw := false args := os.Args[1:] - insecure := false - - if len(args) == 2 && args[0] == "--insecure" { - insecure = true + for len(args) > 0 && strings.HasPrefix(args[0], "-") { + switch args[0] { + case "--insecure": + insecure = true + case "--raw": + raw = true + default: + printHelp("Unknown option " + args[0]) + } args = args[1:] } if len(args) != 1 { - fmt.Fprintf(os.Stderr, usage) - os.Exit(1) + printHelp("") } wsURL := args[0] + urlParsed, err := url.Parse(wsURL) + if err != nil { + printHelp(fmt.Sprintf("Url '%s' is not valid", wsURL)) + } + if !raw { + if len(urlParsed.Query()) > 0 { + printHelp("When not used in raw mode, wsproxy does not accept query parammeters in the URL") + } + wsURL += "?wsproxy" + } dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, @@ -63,9 +104,6 @@ func main() { dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } _wsConn, _, err := dialer.Dial(wsURL, nil) - if err != nil { - panic(err) - } if err != nil { log.Println("WebSocket connection error:", err) panic(err) @@ -73,5 +111,34 @@ func main() { wsConn := websocketutil.NewWebSocketConn(_wsConn, false) defer wsConn.Close() - iowrappers.SynchronizeStreams(wsURL+" -- stdio", wsConn, Stdio{}) + if !raw { + channel := comms.NewGOBChannel(wsConn) + // receive protocol version + protocolVersion, err := comms.ReceiveWithTimeout[comms.ProtocolVersion](channel) + if err != nil { + log.Printf("Error receiving protocol version %v", err) + os.Exit(1) + } + if protocolVersion.Version != comms.PROTOCOL_VERSION { + log.Printf("Protocol version mismmatch: client %d, server %d", + comms.PROTOCOL_VERSION, protocolVersion.Version) + os.Exit(1) + } + // receive confirmation about the agent id. + clientConnectionInfo, err := comms.ReceiveWithTimeout[comms.ClientConnectionInfo](channel) + if err != nil { + log.Printf("Error receiving client connection info: %v", err) + os.Exit(1) + } + if clientConnectionInfo.Ok { + log.Printf("Client connection was accepted, agent is available.") + } else { + log.Printf("Error reported by server: %v", clientConnectionInfo.Message) + os.Exit(1) + } + } + + dataConnection := wsConn + + iowrappers.SynchronizeStreams(wsURL+" -- stdio", dataConnection, Stdio{}) } diff --git a/pkg/comms/events.go b/pkg/comms/events.go index 6f21486..9f028b7 100644 --- a/pkg/comms/events.go +++ b/pkg/comms/events.go @@ -8,13 +8,13 @@ import ( "time" ) -const PROTOCOL_VERSION = 1 +const PROTOCOL_VERSION = 2 func init() { RegisterEventsWithGob() } -// Client to server events +// Agent to server events type AgentInfo struct { Username string @@ -104,10 +104,22 @@ func RegisterEventsWithGob() { gob.Register(ExpiryTimeUpdate{}) gob.Register(HeartBeat{}) - // ConvergeServer to Agent + // ConvergeServer to Agent and client gob.Register(ProtocolVersion{}) + + // ConvergeServer to Agent gob.Register(UserPassword{}) + // ConvergeServer to Client + gob.Register(ClientConnectionInfo{}) + // Wrapper event. gob.Register(ConvergeMessage{}) } + +// Server to client events + +type ClientConnectionInfo struct { + Ok bool + Message string +} diff --git a/pkg/server/converge/admin.go b/pkg/server/converge/admin.go index 8426467..930ef02 100644 --- a/pkg/server/converge/admin.go +++ b/pkg/server/converge/admin.go @@ -189,7 +189,7 @@ func (admin *Admin) addClient(publicId string, clientConn iowrappers2.ReadWriteA agent := admin.agents[publicId] if agent == nil { // we should setup on-demend connections ot agents later. - return nil, fmt.Errorf("No agent found for PublicId '%s'", publicId) + return nil, fmt.Errorf("No agent found for rendez-vous id '%s'", publicId) } agentConn, err := admin.getAgentConnection(agent) @@ -317,16 +317,48 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, return nil } -func (admin *Admin) Connect(publicId string, conn iowrappers2.ReadWriteAddrCloser) error { +func (admin *Admin) Connect(wsProxyMode bool, publicId string, conn iowrappers2.ReadWriteAddrCloser) error { defer conn.Close() + + log.Printf("Using wsproxy protocol %v", wsProxyMode) + channel := comms.NewGOBChannel(conn) + if wsProxyMode { + err := comms.SendWithTimeout( + channel, + comms.ProtocolVersion{ + Version: comms.PROTOCOL_VERSION, + }) + if err != nil { + log.Printf("Error sending protocol version to client %v", err) + return err + } + } + client, err := admin.addClient(publicId, conn) if err != nil { + if wsProxyMode { + _ = comms.SendWithTimeout(channel, + comms.ClientConnectionInfo{ + Ok: false, + Message: err.Error(), + }) + } return err } defer func() { admin.RemoveClient(client) }() log.Printf("Connecting client and agent: '%s'\n", publicId) + if wsProxyMode { + err = comms.SendWithTimeout(channel, + comms.ClientConnectionInfo{ + Ok: true, + Message: "Connecting to agent", + }) + if err != nil { + return fmt.Errorf("Error sending connection info to client: %v", err) + } + } iowrappers2.SynchronizeStreams("client -- agent", client.client, client.agent) return nil