package testsupport

import (
	"context"
	"errors"
	"io"
	"sync"
)

// ChannelReadWriteCloser provides Read, Write and Close on top of a pair of
// byte-slice channels: Read consumes chunks arriving on receiver, and Write
// sends a private copy of each buffer on sender. Read and Write calls that
// would block give up with io.ErrClosedPipe once ctx is done.
type ChannelReadWriteCloser struct {
	// ctx aborts Read and Write calls that are waiting on a channel.
	ctx context.Context

	// receiverMutex serializes Read calls; it guards receiver and readBuf.
	receiverMutex sync.Mutex
	receiver      <-chan []byte
	// readBuf holds the unread tail of the last received chunk that did not
	// fit into the caller's buffer; Read drains it before receiving again.
	readBuf []byte

	// senderMutex serializes Write and Close; it guards sender and the closed
	// flag, which records that Close has already closed sender.
	senderMutex sync.Mutex
	sender      chan<- []byte
	closed      bool
}

// NewChannelReadWriteCloser returns a ChannelReadWriteCloser that reads
// chunks from receiver and writes chunks to sender. Read and Write stop
// waiting on their channel and return io.ErrClosedPipe once ctx is done.
func NewChannelReadWriteCloser(ctx context.Context, receiver <-chan []byte,
	sender chan<- []byte) *ChannelReadWriteCloser {
	// The mutexes, readBuf and the closed flag all start at their zero values,
	// which is exactly the "open, nothing buffered" state.
	rw := new(ChannelReadWriteCloser)
	rw.ctx = ctx
	rw.receiver = receiver
	rw.sender = sender
	return rw
}

// Read copies received bytes into p. It first returns any bytes left over
// from a previously received chunk that did not fit into an earlier p. When
// nothing is buffered, it blocks until a chunk arrives on the receiver
// channel; whatever part of that chunk does not fit into p is kept for the
// next call.
//
// Read returns io.EOF when the receiver channel is closed and nothing is
// buffered, and io.ErrClosedPipe when the context is done while waiting for
// a chunk. A zero-length p returns (0, nil) immediately without consuming
// input. Concurrent Read calls are serialized by receiverMutex.
func (rw *ChannelReadWriteCloser) Read(p []byte) (n int, err error) {
	rw.receiverMutex.Lock()
	defer rw.receiverMutex.Unlock()

	// Without this guard, a zero-length read with bytes pending in readBuf
	// would copy nothing, fall through to the receive below, block for the
	// next chunk, and then overwrite readBuf, silently dropping those bytes.
	if len(p) == 0 {
		return 0, nil
	}

	// Serve leftovers from the previous chunk before touching the channel.
	if len(rw.readBuf) > 0 {
		n = copy(p, rw.readBuf)
		rw.readBuf = rw.readBuf[n:]
		return n, nil
	}

	select {
	case <-rw.ctx.Done():
		return 0, io.ErrClosedPipe
	case data, ok := <-rw.receiver:
		if !ok {
			return 0, io.EOF
		}
		n = copy(p, data)
		// Keep the part of the chunk that did not fit for the next Read.
		rw.readBuf = data[n:]
		return n, nil
	}
}

// Write sends a private copy of pIn as a single chunk on the sender channel.
// The copy is taken because io.Writer implementations must not retain the
// caller's slice, while the receiving side may hold the chunk indefinitely.
//
// Write fails if Close has already been called, and returns io.ErrClosedPipe
// if the context is done before or during the send. senderMutex is held for
// the whole send, so concurrent writes are serialized and Close waits for an
// in-flight Write to finish.
func (rw *ChannelReadWriteCloser) Write(pIn []byte) (n int, err error) {
	chunk := make([]byte, len(pIn))
	copy(chunk, pIn)

	rw.senderMutex.Lock()
	defer rw.senderMutex.Unlock()

	if rw.closed {
		return 0, errors.New("Write on closed channel")
	}

	// A select with several ready cases picks one at random. Checking the
	// context on its own first guarantees that nothing is sent once it is
	// already done.
	select {
	case <-rw.ctx.Done():
		return 0, io.ErrClosedPipe
	default:
	}

	// Block until the chunk is accepted or the context is cancelled mid-send.
	select {
	case rw.sender <- chunk:
		return len(chunk), nil
	case <-rw.ctx.Done():
		return 0, io.ErrClosedPipe
	}
}

// Close closes the sender channel, so a receiver draining it sees the channel
// as closed. Only the first call has an effect; later calls are no-ops.
// Close does not affect the reading side and always returns nil.
func (rw *ChannelReadWriteCloser) Close() error {
	rw.senderMutex.Lock()
	defer rw.senderMutex.Unlock()

	// Closing an already-closed channel panics, so later calls bail out here.
	if rw.closed {
		return nil
	}
	close(rw.sender)
	rw.closed = true
	return nil
}