From a6a0f287dcf44f5d7d4ae0e14bc3ed5b31955121 Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Thu, 18 Jul 2024 20:58:15 +0200 Subject: [PATCH] generalizing websocket connection to reader/writer. --- cmd/tcptows/tcptows.go | 86 +++++++++++++++++++++++++++++++----------- cmd/wstotcp/wstotcp.go | 22 +++++------ 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/cmd/tcptows/tcptows.go b/cmd/tcptows/tcptows.go index 28971cf..dddf8c0 100644 --- a/cmd/tcptows/tcptows.go +++ b/cmd/tcptows/tcptows.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "io" "log" "net" @@ -8,46 +9,49 @@ import ( "github.com/gorilla/websocket" ) -func handleConnection(conn net.Conn, wsURL string) { - defer conn.Close() +func handleConnection(tcpConn net.Conn, wsURL string) { + defer tcpConn.Close() - wsConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + wsConn := NewWebSocketConn(conn) if err != nil { log.Println("WebSocket connection error:", err) return } defer wsConn.Close() + waitChannel := make(chan bool) + go func() { + defer func() { + waitChannel <- true + }() + _, _ = io.Copy(tcpConn, wsConn) + // TODO print error + }() + + go func() { + defer func() { + waitChannel <- true + }() for { - _, message, err := wsConn.ReadMessage() + buffer := make([]byte, 1024) + n, err := tcpConn.Read(buffer) if err != nil { - log.Println("WebSocket read error:", err) + if err != io.EOF { + log.Println("TCP read error:", err) + } return } - _, err = conn.Write(message) + err = conn.WriteMessage(websocket.BinaryMessage, buffer[:n]) if err != nil { - log.Println("TCP write error:", err) + log.Println("WebSocket write error:", err) return } } }() - for { - buffer := make([]byte, 1024) - n, err := conn.Read(buffer) - if err != nil { - if err != io.EOF { - log.Println("TCP read error:", err) - } - return - } - err = wsConn.WriteMessage(websocket.BinaryMessage, buffer[:n]) - if err != nil { - log.Println("WebSocket write error:", err) - return - } - } + <-waitChannel } func main() { @@ -72,3 +76,41 @@ func main() { go handleConnection(conn, wsURL) } } + +type WebSocketConn struct { + conn *websocket.Conn + buf []byte +} + +func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { + return &WebSocketConn{conn: conn} +} + +func (websocketConn *WebSocketConn) Read(p []byte) (n int, err error) { + if len(websocketConn.buf) == 0 { + _, message, err := websocketConn.conn.ReadMessage() + fmt.Println("Got message ", message) + if err != nil { + return 0, err + } + websocketConn.buf = message + } + + n = copy(p, websocketConn.buf) + websocketConn.buf = websocketConn.buf[n:] + + log.Println("Read ", n, " bytes") + return n, err +} + +func (websocketConn *WebSocketConn) Write(p []byte) (n int, err error) { + err = websocketConn.conn.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + n = len(p) + } + return n, err +} + +func (websocketConn *WebSocketConn) Close() error { + return websocketConn.conn.Close() +} diff --git a/cmd/wstotcp/wstotcp.go b/cmd/wstotcp/wstotcp.go index 15c3f43..bdd107b 100644 --- a/cmd/wstotcp/wstotcp.go +++ b/cmd/wstotcp/wstotcp.go @@ -2,11 +2,10 @@ package main import ( "fmt" + "github.com/gorilla/websocket" "log" "net" "net/http" - - "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ @@ -14,26 +13,25 @@ var upgrader = websocket.Upgrader{ WriteBufferSize: 1024, } -var tcpConn net.Conn - func main() { // Connect to TCP server - var err error - tcpConn, err = net.Dial("tcp", "localhost:2222") // Replace with your TCP server address - if err != nil { - log.Fatal("Error connecting to TCP server:", err) - } - defer tcpConn.Close() // Set up WebSocket handler - http.HandleFunc("/ws", handleWebSocket) + http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + tcpConn, err := net.Dial("tcp", "localhost:2222") // Replace with your TCP server address + if err != nil { + log.Fatal("Error connecting to TCP server:", err) + } + defer tcpConn.Close() + handleWebSocket(w, r, tcpConn) + }) // Start HTTP server fmt.Println("WebSocket server listening on :8000") log.Fatal(http.ListenAndServe(":8000", nil)) } -func handleWebSocket(w http.ResponseWriter, r *http.Request) { +func handleWebSocket(w http.ResponseWriter, r *http.Request, tcpConn net.Conn) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println("Error upgrading to WebSocket:", err)