diff --git a/pkg/comms/agentserver.go b/pkg/comms/agentserver.go index 924d093..6a9eeed 100644 --- a/pkg/comms/agentserver.go +++ b/pkg/comms/agentserver.go @@ -213,7 +213,7 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error { } return nil default: - panic(fmt.Errorf("unexpected rolg %v", role)) + panic(fmt.Errorf("unexpected role %v", role)) } } diff --git a/pkg/comms/agentserver_test.go b/pkg/comms/agentserver_test.go index 3fbc8b2..b395104 100644 --- a/pkg/comms/agentserver_test.go +++ b/pkg/comms/agentserver_test.go @@ -5,15 +5,41 @@ import ( "git.wamblee.org/converge/pkg/testsupport" "github.com/stretchr/testify/suite" "log" - "sync" + "net" + "net/http" "testing" + "time" ) type AgentServerTestSuite struct { suite.Suite + + ctx context.Context + cancelFunc context.CancelFunc + pprofServer *http.Server + + agentConnection net.Conn + serverConnection net.Conn } -func (suite *AgentServerTestSuite) SetupTest() { +func (s *AgentServerTestSuite) SetupSuite() { + s.pprofServer = testsupport.StartPprof("") +} + +func (s *AgentServerTestSuite) TearDownSuite() { + testsupport.StopPprof(s.ctx, s.pprofServer) +} + +func (s *AgentServerTestSuite) SetupTest() { + ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second) + s.ctx = ctx + s.cancelFunc = cancelFunc + serverConnection, agentConnection := net.Pipe() + deadline := time.Now().Add(10 * time.Second) + serverConnection.SetDeadline(deadline) + agentConnection.SetDeadline(deadline) + s.serverConnection = serverConnection + s.agentConnection = agentConnection } func (suite *AgentServerTestSuite) TearDownTest() { @@ -23,34 +49,46 @@ func TestAgentServerTestSuite(t *testing.T) { suite.Run(t, &AgentServerTestSuite{}) } -func (suite *AgentServerTestSuite) TestNewCommChannel() { - bitpipe := testsupport.NewInmemoryConnection(context.Background(), "inmemory") - agentConnection := bitpipe.Front() - serverConnection := bitpipe.Back() - requires := suite.Require() +func (s *AgentServerTestSuite) TestNewCommChannel() { - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - log.Println("Agent initializing") - commChannel, err := NewCommChannel(Agent, agentConnection) - requires.Nil(err) - protocolVersion := ProtocolVersion{Version: 10} - err = SendWithTimeout[ProtocolVersion](commChannel.SideChannel, protocolVersion) - requires.Nil(err) - log.Printf("Sent one message %v", protocolVersion) - wg.Done() - }() + // Setup Comm channel + commChannels := testsupport.RunAndWait( + &s.Suite, + func() any { + log.Println("Agent initializing") + commChannel, err := NewCommChannel(Agent, s.agentConnection) + s.Nil(err) + return commChannel + }, + func() any { + log.Println("Server initializing") + commChannel, err := NewCommChannel(ConvergeServer, s.serverConnection) + s.Nil(err) + return commChannel + }, + ) - go func() { - log.Println("Server initializing") - commChannel, err := NewCommChannel(ConvergeServer, serverConnection) - requires.Nil(err) - protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](commChannel.SideChannel) - requires.Nil(err) - log.Printf("Received one message %v", protocolVersion) - wg.Done() - }() - wg.Wait() + s.Equal(2, len(commChannels)) + agentCommChannel := commChannels[0].(CommChannel) + serverCommChannel := commChannels[1].(CommChannel) + // verify the side channel is working by sending an object + testsupport.RunAndWait( + &s.Suite, + func() any { + protocolVersion := ProtocolVersion{Version: 10} + err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, protocolVersion) + s.Nil(err) + log.Printf("Sent one message %v", protocolVersion) + return nil + }, + func() any { + protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel) + s.Nil(err) + log.Printf("Received one message %v", protocolVersion) + return nil + }, + ) + + log.Printf("%v %v", agentCommChannel, serverCommChannel) } diff --git a/pkg/comms/gobchannel.go b/pkg/comms/gobchannel.go index 854112c..d5f0f56 100644 --- a/pkg/comms/gobchannel.go +++ b/pkg/comms/gobchannel.go @@ -70,7 +70,7 @@ func SendWithTimeout[T any](channel GOBChannel, obj T) error { SendAsync(channel, obj, done, errors) select { case <-time.After(MESSAGE_TIMEOUT): - return fmt.Errorf("Timeout in SwndWithTimout") + return fmt.Errorf("Timeout in SendWithTimout") case err := <-errors: return err case <-done: diff --git a/pkg/testsupport/channelreadwritecloser.go b/pkg/testsupport/channelreadwriter.go similarity index 63% rename from pkg/testsupport/channelreadwritecloser.go rename to pkg/testsupport/channelreadwriter.go index c5e9d43..51193f9 100644 --- a/pkg/testsupport/channelreadwritecloser.go +++ b/pkg/testsupport/channelreadwriter.go @@ -5,28 +5,39 @@ import ( "errors" "io" "log" + "runtime" + "sync" ) type ChannelReadWriter struct { - ctx context.Context - receiver <-chan []byte + ctx context.Context + + receiverMutex sync.Mutex + receiver <-chan []byte // bytes that were read and that did not fit readBuf []byte - sender chan<- []byte - closed bool + + senderMutex sync.Mutex + sender chan<- []byte + closed bool } func NewChannelReadWriter(ctx context.Context, receiver <-chan []byte, sender chan<- []byte) *ChannelReadWriter { return &ChannelReadWriter{ - ctx: ctx, - receiver: receiver, - sender: sender, - closed: false, + ctx: ctx, + receiverMutex: sync.Mutex{}, + receiver: receiver, + senderMutex: sync.Mutex{}, + sender: sender, + closed: false, } } func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) { + rw.receiverMutex.Lock() + defer rw.receiverMutex.Unlock() + nread := copy(p, rw.readBuf) if nread > 0 { log.Printf("Read %v bytes", nread) @@ -39,15 +50,18 @@ func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) { return 0, io.ErrClosedPipe case data, ok := <-rw.receiver: if !ok { - return 0, io.EOF + return 0, errors.New("ladida") //io.EOF } nread = copy(p, data) rw.readBuf = data[nread:] return nread, nil } - } + func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { + rw.senderMutex.Lock() + defer rw.senderMutex.Unlock() + if rw.closed { return 0, errors.New("Write on closed channel") } @@ -55,7 +69,7 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { // if context is canceled it should never write select { case <-rw.ctx.Done(): - rw.Close() + //rw.Close() return 0, io.ErrClosedPipe default: } @@ -63,7 +77,7 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { select { // deal with closing duirng the write case <-rw.ctx.Done(): - rw.Close() + //rw.Close() return 0, io.ErrClosedPipe case rw.sender <- p: } @@ -71,7 +85,14 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { } func (rw *ChannelReadWriter) Close() error { + rw.senderMutex.Lock() + defer rw.senderMutex.Unlock() + if !rw.closed { + log.Println("Closing ChannelReadWriter") + buf := make([]byte, 0) + runtime.Stack(buf, false) + log.Printf("Stack %v", string(buf)) close(rw.sender) rw.closed = true } diff --git a/pkg/testsupport/channelreadwriter_test.go b/pkg/testsupport/channelreadwriter_test.go index f22376e..fca43af 100644 --- a/pkg/testsupport/channelreadwriter_test.go +++ b/pkg/testsupport/channelreadwriter_test.go @@ -33,22 +33,18 @@ func (s *ChannelReadWriterTestSuite) createChannel() { fromChannel := make(chan []byte, 10) s.toChannel = toChannel s.fromChannel = fromChannel - ctx, cancelFunc := context.WithCancel(context.Background()) - ctx, timeoutCancelFunc := context.WithTimeout(ctx, 10*time.Second) + ctx, cancelFunc := CreateTestContext(context.Background(), 10*time.Second) s.ctx = ctx - s.cancelFunc = func() { - timeoutCancelFunc() - cancelFunc() - } + s.cancelFunc = cancelFunc s.conn = NewChannelReadWriter(ctx, toChannel, fromChannel) } func (s *ChannelReadWriterTestSuite) SetupSuite() { - s.pprofServer = startPprof("") + s.pprofServer = StartPprof("") } func (s *ChannelReadWriterTestSuite) TearDownSuite() { - stopPprof(s.ctx, s.pprofServer) + StopPprof(s.ctx, s.pprofServer) } func (s *ChannelReadWriterTestSuite) SetupTest() { @@ -127,7 +123,7 @@ func (s *ChannelReadWriterTestSuite) Test_SuccessfulReadWriterToChannel() { func (s *ChannelReadWriterTestSuite) runSuccessfulReadWriteToChannel(test SuccessfulTest) func() { return func() { - runAndWait( + RunAndWait( &s.Suite, func() any { for _, d := range test.data { @@ -163,7 +159,7 @@ type SuccessfulTest struct { func (s *ChannelReadWriterTestSuite) runSuccessfulChannelToReadWrite(test SuccessfulTest) func() { return func() { - runAndWait( + RunAndWait( &s.Suite, func() any { for _, d := range test.data { diff --git a/pkg/testsupport/inmemoryconnection.go b/pkg/testsupport/inmemoryconnection.go index 2739f0e..acf50a1 100644 --- a/pkg/testsupport/inmemoryconnection.go +++ b/pkg/testsupport/inmemoryconnection.go @@ -12,9 +12,9 @@ type InmemoryConnection struct { func NewInmemoryConnection(ctx context.Context, addr string) *InmemoryConnection { pipe := InmemoryConnection{ ctx: ctx, - // arbitrary unbuffered channel, unbuffered is more similar to TCP connections. - frontToBack: make(chan []byte), - backToFront: make(chan []byte), + // TODO: somehow does not work with unbuffered channel and yamux + frontToBack: make(chan []byte, 0), + backToFront: make(chan []byte, 0), addr: addr, } return &pipe diff --git a/pkg/testsupport/inmemoryconnection_test.go b/pkg/testsupport/inmemoryconnection_test.go index 186d85f..90d3580 100644 --- a/pkg/testsupport/inmemoryconnection_test.go +++ b/pkg/testsupport/inmemoryconnection_test.go @@ -1 +1,45 @@ package testsupport + +import ( + "context" + "github.com/stretchr/testify/suite" + "net/http" + "testing" + "time" +) + +type InMemoryTestSuite struct { + suite.Suite + + pprofServer *http.Server + ctx context.Context + cancelFunc context.CancelFunc + pipe *InmemoryConnection +} + +func TestInMemoryConnectionTestSuite(t *testing.T) { + suite.Run(t, &InMemoryTestSuite{}) +} + +func (s *InMemoryTestSuite) createConnection() { + ctx, cancelFunc := CreateTestContext(context.Background(), 10*time.Second) + s.ctx = ctx + s.cancelFunc = cancelFunc + s.pipe = NewInmemoryConnection(ctx, "inmemory") +} + +func (s *InMemoryTestSuite) SetupSuite() { + s.pprofServer = StartPprof("") +} + +func (s *InMemoryTestSuite) TearDownSuite() { + StopPprof(s.ctx, s.pprofServer) +} + +func (s *InMemoryTestSuite) SetupTest() { + s.createConnection() +} + +func (s *InMemoryTestSuite) TearDownTest() { + s.cancelFunc() +} diff --git a/pkg/testsupport/utils.go b/pkg/testsupport/utils.go index 7be0926..29787d1 100644 --- a/pkg/testsupport/utils.go +++ b/pkg/testsupport/utils.go @@ -8,11 +8,12 @@ import ( "net/http" _ "runtime/pprof" "sync" + "time" ) type TestFunction func() any -func runAndWait(suite *suite.Suite, functions ...TestFunction) []any { +func RunAndWait(suite *suite.Suite, functions ...TestFunction) []any { wg := sync.WaitGroup{} wg.Add(len(functions)) res := make([]any, len(functions)) @@ -28,7 +29,7 @@ func runAndWait(suite *suite.Suite, functions ...TestFunction) []any { return res } -func startPprof(port string) *http.Server { +func StartPprof(port string) *http.Server { if port == "" { port = ":9000" } @@ -47,7 +48,7 @@ func startPprof(port string) *http.Server { return &srv } -func stopPprof(ctx context.Context, server *http.Server) { +func StopPprof(ctx context.Context, server *http.Server) { err := server.Shutdown(ctx) if err != nil { log.Println("Error shutting down test pprof server") @@ -55,3 +56,13 @@ func stopPprof(ctx context.Context, server *http.Server) { } log.Println("Test pprof server stopped") } + +func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + ctx, cancelFunc := context.WithCancel(ctx) + ctx, timeoutCancelFunc := context.WithTimeout(ctx, timeout) + compositeCancelFunc := func() { + timeoutCancelFunc() + cancelFunc() + } + return ctx, compositeCancelFunc +}