package websocketutil

import (
	"github.com/gorilla/websocket"
	"net"
	"time"
)

// WebSocketConn adapts a gorilla/websocket connection so it can be used as a
// net.Conn, presenting the peer's messages as one continuous byte stream.
type WebSocketConn struct {
	// conn is the wrapped WebSocket connection that all I/O goes through.
	conn *websocket.Conn

	// buf holds the unread remainder of the most recently received message;
	// Read drains it before fetching the next message from conn.
	buf []byte
}

// Read implements io.Reader on top of the WebSocket message stream.
//
// Message boundaries are not preserved: each received message is buffered
// and handed out across as many Read calls as needed to drain it, and only
// then is the next message read from the connection. The message type
// (text or binary) is ignored.
//
// Empty messages are skipped, so for a non-empty p Read either returns at
// least one byte or a non-nil error, never (0, nil). A zero-length p returns
// (0, nil) immediately without touching the network.
func (websocketConn *WebSocketConn) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		// Nothing can be delivered, so don't block waiting for a message.
		return 0, nil
	}

	// Loop rather than read once: a zero-length message would otherwise
	// surface as a (0, nil) read, which io.Reader discourages.
	for len(websocketConn.buf) == 0 {
		_, message, readErr := websocketConn.conn.ReadMessage()
		if readErr != nil {
			return 0, readErr
		}
		websocketConn.buf = message
	}

	n = copy(p, websocketConn.buf)
	websocketConn.buf = websocketConn.buf[n:]
	if len(websocketConn.buf) == 0 {
		// Drop the reference so the fully drained message can be collected.
		websocketConn.buf = nil
	}
	return n, nil
}

// NewWebSocketConn wraps an established gorilla/websocket connection so it
// can be used wherever a net.Conn is expected. The read buffer starts empty.
func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn {
	wrapped := &WebSocketConn{}
	wrapped.conn = conn
	return wrapped
}

// Write sends p to the peer as a single binary WebSocket message.
//
// On success it reports len(p) bytes written. On failure it reports 0 bytes
// together with the error from the underlying connection.
//
// NOTE(review): gorilla/websocket permits only one concurrent writer per
// connection; confirm callers serialize calls to Write.
func (websocketConn *WebSocketConn) Write(p []byte) (n int, err error) {
	if err = websocketConn.conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
		return 0, err
	}
	return len(p), nil
}

// Close tells the peer the connection is ending and then closes the
// underlying connection.
//
// It first makes a best-effort attempt to send a WebSocket close frame with
// the normal-closure status code. gorilla's Conn.Close alone would tear down
// the socket silently, and the peer would observe an abnormal closure. The
// attempt is bounded by a short deadline so Close cannot hang on an
// unresponsive peer. The returned error comes from closing the underlying
// connection.
func (websocketConn *WebSocketConn) Close() error {
	const closeFrameTimeout = time.Second

	// Deliberately ignored: the peer may already be gone or a close frame
	// may already have been sent, and the connection must be torn down
	// regardless.
	_ = websocketConn.conn.WriteControl(
		websocket.CloseMessage,
		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
		time.Now().Add(closeFrameTimeout),
	)
	return websocketConn.conn.Close()
}

// LocalAddr reports the local network address of the underlying WebSocket
// connection.
func (websocketConn *WebSocketConn) LocalAddr() net.Addr {
	return websocketConn.conn.LocalAddr()
}

// RemoteAddr reports the network address of the peer, as given by the
// underlying WebSocket connection.
func (websocketConn *WebSocketConn) RemoteAddr() net.Addr {
	underlying := websocketConn.conn
	return underlying.RemoteAddr()
}

// SetDeadline sets both the read and write deadlines on the underlying
// WebSocket connection, as the net.Conn contract requires. It previously
// ignored t, so deadlines set by callers never took effect.
//
// The read deadline is applied first. If that fails, its error is returned
// and the write deadline is left unchanged.
//
// NOTE: gorilla/websocket documents that once a read or write times out the
// connection state is corrupt and later operations of that kind fail. Treat
// a deadline that fires as terminal for this connection.
func (websocketConn *WebSocketConn) SetDeadline(t time.Time) error {
	if err := websocketConn.conn.SetReadDeadline(t); err != nil {
		return err
	}
	return websocketConn.conn.SetWriteDeadline(t)
}

// SetReadDeadline sets the read deadline on the underlying WebSocket
// connection. It previously ignored t, which could leave Read blocked
// indefinitely despite a caller-imposed timeout.
//
// Data already buffered from an earlier message is returned by Read without
// consulting the connection, so the deadline only affects reads that must
// wait for a new message.
//
// NOTE: gorilla/websocket documents that after a read times out the
// connection state is corrupt and all future reads return an error.
func (websocketConn *WebSocketConn) SetReadDeadline(t time.Time) error {
	return websocketConn.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the underlying WebSocket
// connection. It previously ignored t, so Write could block past a
// caller-imposed timeout.
//
// NOTE: gorilla/websocket documents that after a write times out the
// connection state is corrupt and all future writes return an error.
func (websocketConn *WebSocketConn) SetWriteDeadline(t time.Time) error {
	return websocketConn.conn.SetWriteDeadline(t)
}

// ConnectWebSocket performs a WebSocket client handshake for urlStr over the
// already-established connection conn and returns the resulting WebSocket
// connection wrapped as a net.Conn.
//
// The dialer never opens a network connection of its own: its NetDial hook
// always hands back conn. The handshake uses a copy of
// websocket.DefaultDialer with proxying disabled. conn is already connected
// to the destination, so a proxy chosen from the environment would make the
// dialer send an HTTP CONNECT request over the wrong connection.
func ConnectWebSocket(conn net.Conn, urlStr string) (net.Conn, error) {
	dialer := *websocket.DefaultDialer
	// DefaultDialer consults HTTP_PROXY/HTTPS_PROXY. conn already reaches the
	// target, so a proxy hop must never be attempted over it.
	dialer.Proxy = nil
	dialer.NetDial = func(network, addr string) (net.Conn, error) {
		return conn, nil
	}

	// The handshake response is not used. gorilla/websocket documents that
	// its body need not be closed by the application.
	wsConn, _, err := dialer.Dial(urlStr, nil)
	if err != nil {
		return nil, err
	}
	return NewWebSocketConn(wsConn), nil
}