From 188d949d658d20742705ee856d86c448bc5fa7de Mon Sep 17 00:00:00 2001 From: Erik Brakkee Date: Thu, 18 Jul 2024 21:14:23 +0200 Subject: [PATCH] moved websocket wrapper to iowrappers package. --- cmd/tcptows/tcptows.go | 68 ++++++------------------------------------ go.mod | 2 +- pkg/iowrappers/io.go | 39 ++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 60 deletions(-) create mode 100644 pkg/iowrappers/io.go diff --git a/cmd/tcptows/tcptows.go b/cmd/tcptows/tcptows.go index dddf8c0..3f36621 100644 --- a/cmd/tcptows/tcptows.go +++ b/cmd/tcptows/tcptows.go @@ -1,19 +1,18 @@ package main import ( - "fmt" + "cidebug/pkg/iowrappers" + "github.com/gorilla/websocket" "io" "log" "net" - - "github.com/gorilla/websocket" ) func handleConnection(tcpConn net.Conn, wsURL string) { defer tcpConn.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - wsConn := NewWebSocketConn(conn) + wsConn := iowrappers.NewWebSocketConn(conn) if err != nil { log.Println("WebSocket connection error:", err) return @@ -26,29 +25,18 @@ func handleConnection(tcpConn net.Conn, wsURL string) { defer func() { waitChannel <- true }() - _, _ = io.Copy(tcpConn, wsConn) - // TODO print error + _, err = io.Copy(tcpConn, wsConn) + if err != nil { + log.Printf("error %v", err) + } }() go func() { defer func() { waitChannel <- true }() - for { - buffer := make([]byte, 1024) - n, err := tcpConn.Read(buffer) - if err != nil { - if err != io.EOF { - log.Println("TCP read error:", err) - } - return - } - err = conn.WriteMessage(websocket.BinaryMessage, buffer[:n]) - if err != nil { - log.Println("WebSocket write error:", err) - return - } - } + _, err = io.Copy(wsConn, tcpConn) + log.Printf("Error %v", err) }() <-waitChannel @@ -76,41 +64,3 @@ func main() { go handleConnection(conn, wsURL) } } - -type WebSocketConn struct { - conn *websocket.Conn - buf []byte -} - -func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { - return &WebSocketConn{conn: conn} -} - -func (websocketConn *WebSocketConn) Read(p []byte) (n int, err error) { - if len(websocketConn.buf) == 0 { - _, message, err := websocketConn.conn.ReadMessage() - fmt.Println("Got message ", message) - if err != nil { - return 0, err - } - websocketConn.buf = message - } - - n = copy(p, websocketConn.buf) - websocketConn.buf = websocketConn.buf[n:] - - log.Println("Read ", n, " bytes") - return n, err -} - -func (websocketConn *WebSocketConn) Write(p []byte) (n int, err error) { - err = websocketConn.conn.WriteMessage(websocket.BinaryMessage, p) - if err != nil { - n = len(p) - } - return n, err -} - -func (websocketConn *WebSocketConn) Close() error { - return websocketConn.conn.Close() -} diff --git a/go.mod b/go.mod index 5c5eb1b..2f22ceb 100755 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module ssh +module cidebug go 1.18 diff --git a/pkg/iowrappers/io.go b/pkg/iowrappers/io.go new file mode 100644 index 0000000..68ca433 --- /dev/null +++ b/pkg/iowrappers/io.go @@ -0,0 +1,39 @@ +package iowrappers + +import "github.com/gorilla/websocket" + +type WebSocketConn struct { + conn *websocket.Conn + buf []byte +} + +func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { + return &WebSocketConn{conn: conn} +} + +func (websocketConn *WebSocketConn) Read(p []byte) (n int, err error) { + if len(websocketConn.buf) == 0 { + _, message, err := websocketConn.conn.ReadMessage() + if err != nil { + return 0, err + } + websocketConn.buf = message + } + + n = copy(p, websocketConn.buf) + websocketConn.buf = websocketConn.buf[n:] + + return n, err +} + +func (websocketConn *WebSocketConn) Write(p []byte) (n int, err error) { + err = websocketConn.conn.WriteMessage(websocket.BinaryMessage, p) + if err == nil { + n = len(p) + } + return n, err +} + +func (websocketConn *WebSocketConn) Close() error { + return websocketConn.conn.Close() +}