generalizing websocket connection to reader/writer.
This commit is contained in:
		
							parent
							
								
									4eb83d033a
								
							
						
					
					
						commit
						46a7588896
					
				| @ -1,6 +1,7 @@ | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| @ -8,46 +9,49 @@ import ( | ||||
| 	"github.com/gorilla/websocket" | ||||
| ) | ||||
| 
 | ||||
| func handleConnection(conn net.Conn, wsURL string) { | ||||
| 	defer conn.Close() | ||||
| func handleConnection(tcpConn net.Conn, wsURL string) { | ||||
| 	defer tcpConn.Close() | ||||
| 
 | ||||
| 	wsConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | ||||
| 	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | ||||
| 	wsConn := NewWebSocketConn(conn) | ||||
| 	if err != nil { | ||||
| 		log.Println("WebSocket connection error:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	defer wsConn.Close() | ||||
| 
 | ||||
| 	waitChannel := make(chan bool) | ||||
| 
 | ||||
| 	go func() { | ||||
| 		defer func() { | ||||
| 			waitChannel <- true | ||||
| 		}() | ||||
| 		_, _ = io.Copy(tcpConn, wsConn) | ||||
| 		// TODO print error
 | ||||
| 	}() | ||||
| 
 | ||||
| 	go func() { | ||||
| 		defer func() { | ||||
| 			waitChannel <- true | ||||
| 		}() | ||||
| 		for { | ||||
| 			_, message, err := wsConn.ReadMessage() | ||||
| 			buffer := make([]byte, 1024) | ||||
| 			n, err := tcpConn.Read(buffer) | ||||
| 			if err != nil { | ||||
| 				log.Println("WebSocket read error:", err) | ||||
| 				if err != io.EOF { | ||||
| 					log.Println("TCP read error:", err) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 			_, err = conn.Write(message) | ||||
| 			err = conn.WriteMessage(websocket.BinaryMessage, buffer[:n]) | ||||
| 			if err != nil { | ||||
| 				log.Println("TCP write error:", err) | ||||
| 				log.Println("WebSocket write error:", err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	for { | ||||
| 		buffer := make([]byte, 1024) | ||||
| 		n, err := conn.Read(buffer) | ||||
| 		if err != nil { | ||||
| 			if err != io.EOF { | ||||
| 				log.Println("TCP read error:", err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 		err = wsConn.WriteMessage(websocket.BinaryMessage, buffer[:n]) | ||||
| 		if err != nil { | ||||
| 			log.Println("WebSocket write error:", err) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	<-waitChannel | ||||
| } | ||||
| 
 | ||||
| func main() { | ||||
| @ -72,3 +76,41 @@ func main() { | ||||
| 		go handleConnection(conn, wsURL) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type WebSocketConn struct { | ||||
| 	conn *websocket.Conn | ||||
| 	buf  []byte | ||||
| } | ||||
| 
 | ||||
| func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { | ||||
| 	return &WebSocketConn{conn: conn} | ||||
| } | ||||
| 
 | ||||
| func (websocketConn *WebSocketConn) Read(p []byte) (n int, err error) { | ||||
| 	if len(websocketConn.buf) == 0 { | ||||
| 		_, message, err := websocketConn.conn.ReadMessage() | ||||
| 		fmt.Println("Got message ", message) | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		websocketConn.buf = message | ||||
| 	} | ||||
| 
 | ||||
| 	n = copy(p, websocketConn.buf) | ||||
| 	websocketConn.buf = websocketConn.buf[n:] | ||||
| 
 | ||||
| 	log.Println("Read ", n, " bytes") | ||||
| 	return n, err | ||||
| } | ||||
| 
 | ||||
| func (websocketConn *WebSocketConn) Write(p []byte) (n int, err error) { | ||||
| 	err = websocketConn.conn.WriteMessage(websocket.BinaryMessage, p) | ||||
| 	if err != nil { | ||||
| 		n = len(p) | ||||
| 	} | ||||
| 	return n, err | ||||
| } | ||||
| 
 | ||||
| func (websocketConn *WebSocketConn) Close() error { | ||||
| 	return websocketConn.conn.Close() | ||||
| } | ||||
|  | ||||
| @ -2,11 +2,10 @@ package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/gorilla/websocket" | ||||
| ) | ||||
| 
 | ||||
| var upgrader = websocket.Upgrader{ | ||||
| @ -14,26 +13,25 @@ var upgrader = websocket.Upgrader{ | ||||
| 	WriteBufferSize: 1024, | ||||
| } | ||||
| 
 | ||||
| var tcpConn net.Conn | ||||
| 
 | ||||
| func main() { | ||||
| 	// Connect to TCP server
 | ||||
| 	var err error | ||||
| 	tcpConn, err = net.Dial("tcp", "localhost:2222") // Replace with your TCP server address
 | ||||
| 	if err != nil { | ||||
| 		log.Fatal("Error connecting to TCP server:", err) | ||||
| 	} | ||||
| 	defer tcpConn.Close() | ||||
| 
 | ||||
| 	// Set up WebSocket handler
 | ||||
| 	http.HandleFunc("/ws", handleWebSocket) | ||||
| 	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		tcpConn, err := net.Dial("tcp", "localhost:2222") // Replace with your TCP server address
 | ||||
| 		if err != nil { | ||||
| 			log.Fatal("Error connecting to TCP server:", err) | ||||
| 		} | ||||
| 		defer tcpConn.Close() | ||||
| 		handleWebSocket(w, r, tcpConn) | ||||
| 	}) | ||||
| 
 | ||||
| 	// Start HTTP server
 | ||||
| 	fmt.Println("WebSocket server listening on :8000") | ||||
| 	log.Fatal(http.ListenAndServe(":8000", nil)) | ||||
| } | ||||
| 
 | ||||
| func handleWebSocket(w http.ResponseWriter, r *http.Request) { | ||||
| func handleWebSocket(w http.ResponseWriter, r *http.Request, tcpConn net.Conn) { | ||||
| 	conn, err := upgrader.Upgrade(w, r, nil) | ||||
| 	if err != nil { | ||||
| 		log.Println("Error upgrading to WebSocket:", err) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user