diff --git a/Makefile b/Makefile index f6cec8a..82d42a9 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ generate: vet: fmt go vet ./... -test: +test: build go test -count=1 -v ./... build: generate vet diff --git a/pkg/support/iowrappers/channelreadwritecloser.go b/pkg/support/iowrappers/channelreadwritecloser.go index e271702..11b312f 100644 --- a/pkg/support/iowrappers/channelreadwritecloser.go +++ b/pkg/support/iowrappers/channelreadwritecloser.go @@ -2,6 +2,7 @@ package iowrappers import ( "context" + "errors" "io" "log" ) @@ -28,17 +29,16 @@ func NewChannelReadWriter(ctx context.Context, receiver <-chan []byte, func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) { nread := copy(p, rw.readBuf) if nread > 0 { + log.Printf("Read %v bytes", nread) rw.readBuf = rw.readBuf[nread:] return nread, nil } select { case <-rw.ctx.Done(): - log.Println("Context was canceled") return 0, io.EOF case data, ok := <-rw.receiver: if !ok { - log.Println("Channel closed") return 0, io.EOF } nread = copy(p, data) @@ -48,13 +48,18 @@ func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) { } func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { + if rw.closed { + return 0, errors.New("Write on closed channel") + } select { case <-rw.ctx.Done(): + rw.Close() return 0, io.EOF case rw.sender <- p: } return len(p), nil } + func (rw *ChannelReadWriter) Close() error { if !rw.closed { close(rw.sender) diff --git a/pkg/support/iowrappers/channelreadwriter_test.go b/pkg/support/iowrappers/channelreadwriter_test.go index 9cf2585..c1efe7a 100644 --- a/pkg/support/iowrappers/channelreadwriter_test.go +++ b/pkg/support/iowrappers/channelreadwriter_test.go @@ -35,7 +35,6 @@ func (suite *ChannelReadWriterTestSuite) createChannel() { timeoutCancelFunc() cancelFunc() } - suite.cancelFunc = cancelFunc suite.conn = NewChannelReadWriter(ctx, toChannel, fromChannel) } @@ -49,47 +48,77 @@ func (suite *ChannelReadWriterTestSuite) TearDownTest() { type TestFunc func() any -func runAndWait(functions ...TestFunc) []any { +func (suite *ChannelReadWriterTestSuite) runAndWait(functions ...TestFunc) []any { wg := sync.WaitGroup{} wg.Add(len(functions)) res := make([]any, len(functions)) for i, function := range functions { go func() { + defer func() { + wg.Done() + }() res[i] = function() - wg.Done() }() } wg.Wait() return res } -func (suite *ChannelReadWriterTestSuite) Test_SlicesLargeEnough() { - requires := suite.Require() - data := []byte("hello") - - runAndWait( - func() any { - suite.toChannel <- data - log.Println("data sent") - return nil +func (suite *ChannelReadWriterTestSuite) Test_SuccessfulCommunication() { + tests := []struct { + name string + data string + chunkSizes []int + chunks []string + }{ + { + name: "buffer_large_enough", + data: "hello", + chunkSizes: []int{10}, + chunks: []string{"hello"}, }, - func() any { - buf := make([]byte, len(data)*2) - n, err := suite.conn.Read(buf) - requires.Nil(err) - requires.Equal(n, len(data)) - requires.Equal(data, buf[:n]) - return nil + { + name: "two_reads_required", + data: "hello", + chunkSizes: []int{3, 10}, + chunks: []string{"hel", "lo"}, }, - ) -} + { + name: "many_reads_required", + data: "hello", + chunkSizes: []int{1, 1, 1, 1, 1}, + chunks: []string{"h", "e", "l", "l", "o"}, + }, + } -func (suite *ChannelReadWriterTestSuite) Test_SliceTooSmallFullReadInTwoParts() { - suite.FailNow("todo") -} + for _, test := range tests { + suite.Run(test.name, func() { + suite.runAndWait( + func() any { + select { + case <-suite.ctx.Done(): + suite.FailNow("deadline reached") + log.Println("Write deadline exceeded") + case suite.toChannel <- []byte(test.data): + } + return nil + }, + func() any { + remainder := test.data + for i, chunkSize := range test.chunkSizes { + buf := make([]byte, chunkSize) + n, err := suite.conn.Read(buf) + suite.Nil(err) + suite.Equal(n, len(test.chunks[i])) + suite.Equal([]byte(remainder[:n]), buf[:n]) + remainder = remainder[n:] + } + return nil + }, + ) + }) + } -func (suite *ChannelReadWriterTestSuite) Test_SliceTooSmallFullREadInManyParts() { - suite.FailNow("todo") } func (suite *ChannelReadWriterTestSuite) Test_Close() {