package testsupport import ( "context" "github.com/stretchr/testify/suite" "io" "log" "net/http" "strings" "testing" "time" ) type ChannelReadWriterTestSuite struct { suite.Suite pprofServer *http.Server ctx context.Context cancelFunc context.CancelFunc toChannel chan<- []byte fromChannel <-chan []byte conn *ChannelReadWriteCloser } func TestChannelReadWriterSuite(t *testing.T) { suite.Run(t, &ChannelReadWriterTestSuite{}) } func (s *ChannelReadWriterTestSuite) createChannel() { // buffered channels provide more similar behavior to TCP connections // then unbuffered channels. toChannel := make(chan []byte, 10) fromChannel := make(chan []byte, 10) s.toChannel = toChannel s.fromChannel = fromChannel ctx, cancelFunc := CreateTestContext(context.Background(), 10*time.Second) s.ctx = ctx s.cancelFunc = cancelFunc s.conn = NewChannelReadWriteCloser(ctx, toChannel, fromChannel) } func (s *ChannelReadWriterTestSuite) SetupSuite() { s.pprofServer = StartPprof("") } func (s *ChannelReadWriterTestSuite) TearDownSuite() { StopPprof(s.ctx, s.pprofServer) } func (s *ChannelReadWriterTestSuite) SetupTest() { s.createChannel() } func (s *ChannelReadWriterTestSuite) TearDownTest() { s.cancelFunc() } func (s *ChannelReadWriterTestSuite) Test_SuccessfulChannelToReadWriter() { tests := []SuccessfulTest{ { name: "buffer_large_enough", data: []string{"hello"}, chunkSizes: []int{10}, chunks: []string{"hello"}, }, { name: "two_reads_required", data: []string{"hello"}, chunkSizes: []int{3, 10}, chunks: []string{"hel", "lo"}, }, { name: "many_reads_required", data: []string{"hello"}, chunkSizes: []int{1, 1, 1, 1, 1}, chunks: []string{"h", "e", "l", "l", "o"}, }, { name: "buffer_large_enough_multiple_writes", data: []string{"hel", "lo"}, chunkSizes: []int{3, 2}, chunks: []string{"hel", "lo"}, }, { // NOTE: no intelligence in the reader to fill up the read buffer when it is not full // therefore, the second read will have only 1 char since the first channel read returned // 3 of which 2 where returned in the first read call to the ChannelReadWriteCloser. name: "buffer_too_small_multiple_writes", data: []string{"hel", "lo"}, chunkSizes: []int{2, 2, 2}, chunks: []string{"he", "l", "lo"}, }, } for _, test := range tests { s.Run(test.name, s.runSuccessfulChannelToReadWrite(test)) } } func (s *ChannelReadWriterTestSuite) Test_SuccessfulReadWriterToChannel() { tests := []SuccessfulTest{ { name: "buffer_large_enough", data: []string{"hello"}, chunkSizes: []int{1}, }, { name: "two_reads_required", data: []string{"hel", "lo"}, chunkSizes: []int{3, 2}, }, { name: "many_reads_required", data: []string{"h", "e", "l", "l", "o"}, chunkSizes: []int{1, 1, 1, 1, 1}, }, } for _, test := range tests { s.Run(test.name, s.runSuccessfulReadWriteToChannel(test)) } } func (s *ChannelReadWriterTestSuite) runSuccessfulReadWriteToChannel(test SuccessfulTest) func() { return func() { RunAndWait( &s.Suite, func() any { for _, d := range test.data { n, err := s.conn.Write([]byte(d)) s.Nil(err) s.Equal(len(d), n) } return nil }, func() any { for _, chunk := range test.data { select { case <-s.ctx.Done(): s.Fail("context canceled") case d, ok := <-s.fromChannel: s.True(ok) s.Equal([]byte(chunk), d) } } return nil }, ) } } type SuccessfulTest struct { name string data []string chunkSizes []int chunks []string } func (s *ChannelReadWriterTestSuite) runSuccessfulChannelToReadWrite(test SuccessfulTest) func() { return func() { RunAndWait( &s.Suite, func() any { for _, d := range test.data { select { case <-s.ctx.Done(): s.FailNow("deadline reached") log.Println("Write deadline exceeded") case s.toChannel <- []byte(d): } } return nil }, func() any { remainder := strings.Join(test.data, "") for i, chunkSize := range test.chunkSizes { buf := make([]byte, chunkSize) n, err := s.conn.Read(buf) s.Nil(err) s.Equal(n, len(test.chunks[i])) s.Equal([]byte(remainder[:n]), buf[:n]) remainder = remainder[n:] } return nil }, ) } } func (s *ChannelReadWriterTestSuite) Test_ChannelCloseBeforeRead() { data := "hello" // buffered channel s.toChannel <- []byte(data) close(s.toChannel) buf := make([]byte, len(data)) n, err := s.conn.Read(buf) s.Nil(err) s.Equal(len(data), n) s.Equal([]byte(data), buf[:n]) } func (s *ChannelReadWriterTestSuite) Test_ChannelCloseAfterWrite() { data := "hello" n, err := s.conn.Write([]byte(data)) s.Nil(err) s.Equal(len(data), n) err = s.conn.Close() s.Nil(err) select { case <-s.ctx.Done(): s.Fail("channel closed") case d, ok := <-s.fromChannel: s.True(ok) s.Equal([]byte(data), d) } n, err = s.conn.Write([]byte(data)) s.NotNil(err) s.Equal(0, n) } func (s *ChannelReadWriterTestSuite) Test_CloseTwice() { err := s.conn.Close() s.Nil(err) s.True(s.conn.closed) err = s.conn.Close() s.Nil(err) } func (s *ChannelReadWriterTestSuite) Test_ContextCanceledRead() { s.cancelFunc() buf := make([]byte, 100) n, err := s.conn.Read(buf) s.Equal(io.ErrClosedPipe, err) s.Equal(0, n) } func (s *ChannelReadWriterTestSuite) Test_ContextCanceledWrite() { s.cancelFunc() n, err := s.conn.Write([]byte("hello")) s.Equal(io.ErrClosedPipe, err) s.Equal(0, n) }