discovered net.Pipe for testing tcp connnections which makes the

previously developed ChannelReadWriter and InmemoryConnection obsolete.
This commit is contained in:
Erik Brakkee 2024-08-20 11:28:09 +02:00
parent c55c4aa365
commit ed04ac035a
8 changed files with 169 additions and 59 deletions

View File

@ -213,7 +213,7 @@ func CheckProtocolVersion(role Role, channel GOBChannel) error {
} }
return nil return nil
default: default:
panic(fmt.Errorf("unexpected rolg %v", role)) panic(fmt.Errorf("unexpected role %v", role))
} }
} }

View File

@ -5,15 +5,41 @@ import (
"git.wamblee.org/converge/pkg/testsupport" "git.wamblee.org/converge/pkg/testsupport"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"log" "log"
"sync" "net"
"net/http"
"testing" "testing"
"time"
) )
type AgentServerTestSuite struct { type AgentServerTestSuite struct {
suite.Suite suite.Suite
ctx context.Context
cancelFunc context.CancelFunc
pprofServer *http.Server
agentConnection net.Conn
serverConnection net.Conn
} }
func (suite *AgentServerTestSuite) SetupTest() { func (s *AgentServerTestSuite) SetupSuite() {
s.pprofServer = testsupport.StartPprof("")
}
func (s *AgentServerTestSuite) TearDownSuite() {
testsupport.StopPprof(s.ctx, s.pprofServer)
}
func (s *AgentServerTestSuite) SetupTest() {
ctx, cancelFunc := testsupport.CreateTestContext(context.Background(), 10*time.Second)
s.ctx = ctx
s.cancelFunc = cancelFunc
serverConnection, agentConnection := net.Pipe()
deadline := time.Now().Add(10 * time.Second)
serverConnection.SetDeadline(deadline)
agentConnection.SetDeadline(deadline)
s.serverConnection = serverConnection
s.agentConnection = agentConnection
} }
func (suite *AgentServerTestSuite) TearDownTest() { func (suite *AgentServerTestSuite) TearDownTest() {
@ -23,34 +49,46 @@ func TestAgentServerTestSuite(t *testing.T) {
suite.Run(t, &AgentServerTestSuite{}) suite.Run(t, &AgentServerTestSuite{})
} }
func (suite *AgentServerTestSuite) TestNewCommChannel() { func (s *AgentServerTestSuite) TestNewCommChannel() {
bitpipe := testsupport.NewInmemoryConnection(context.Background(), "inmemory")
agentConnection := bitpipe.Front()
serverConnection := bitpipe.Back()
requires := suite.Require()
wg := sync.WaitGroup{} // Setup Comm channel
wg.Add(2) commChannels := testsupport.RunAndWait(
go func() { &s.Suite,
log.Println("Agent initializing") func() any {
commChannel, err := NewCommChannel(Agent, agentConnection) log.Println("Agent initializing")
requires.Nil(err) commChannel, err := NewCommChannel(Agent, s.agentConnection)
protocolVersion := ProtocolVersion{Version: 10} s.Nil(err)
err = SendWithTimeout[ProtocolVersion](commChannel.SideChannel, protocolVersion) return commChannel
requires.Nil(err) },
log.Printf("Sent one message %v", protocolVersion) func() any {
wg.Done() log.Println("Server initializing")
}() commChannel, err := NewCommChannel(ConvergeServer, s.serverConnection)
s.Nil(err)
return commChannel
},
)
go func() { s.Equal(2, len(commChannels))
log.Println("Server initializing") agentCommChannel := commChannels[0].(CommChannel)
commChannel, err := NewCommChannel(ConvergeServer, serverConnection) serverCommChannel := commChannels[1].(CommChannel)
requires.Nil(err)
protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](commChannel.SideChannel)
requires.Nil(err)
log.Printf("Received one message %v", protocolVersion)
wg.Done()
}()
wg.Wait()
// verify the side channel is working by sending an object
testsupport.RunAndWait(
&s.Suite,
func() any {
protocolVersion := ProtocolVersion{Version: 10}
err := SendWithTimeout[ProtocolVersion](agentCommChannel.SideChannel, protocolVersion)
s.Nil(err)
log.Printf("Sent one message %v", protocolVersion)
return nil
},
func() any {
protocolVersion, err := ReceiveWithTimeout[ProtocolVersion](serverCommChannel.SideChannel)
s.Nil(err)
log.Printf("Received one message %v", protocolVersion)
return nil
},
)
log.Printf("%v %v", agentCommChannel, serverCommChannel)
} }

View File

@ -70,7 +70,7 @@ func SendWithTimeout[T any](channel GOBChannel, obj T) error {
SendAsync(channel, obj, done, errors) SendAsync(channel, obj, done, errors)
select { select {
case <-time.After(MESSAGE_TIMEOUT): case <-time.After(MESSAGE_TIMEOUT):
return fmt.Errorf("Timeout in SwndWithTimout") return fmt.Errorf("Timeout in SendWithTimout")
case err := <-errors: case err := <-errors:
return err return err
case <-done: case <-done:

View File

@ -5,28 +5,39 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"runtime"
"sync"
) )
type ChannelReadWriter struct { type ChannelReadWriter struct {
ctx context.Context ctx context.Context
receiver <-chan []byte
receiverMutex sync.Mutex
receiver <-chan []byte
// bytes that were read and that did not fit // bytes that were read and that did not fit
readBuf []byte readBuf []byte
sender chan<- []byte
closed bool senderMutex sync.Mutex
sender chan<- []byte
closed bool
} }
func NewChannelReadWriter(ctx context.Context, receiver <-chan []byte, func NewChannelReadWriter(ctx context.Context, receiver <-chan []byte,
sender chan<- []byte) *ChannelReadWriter { sender chan<- []byte) *ChannelReadWriter {
return &ChannelReadWriter{ return &ChannelReadWriter{
ctx: ctx, ctx: ctx,
receiver: receiver, receiverMutex: sync.Mutex{},
sender: sender, receiver: receiver,
closed: false, senderMutex: sync.Mutex{},
sender: sender,
closed: false,
} }
} }
func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) { func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) {
rw.receiverMutex.Lock()
defer rw.receiverMutex.Unlock()
nread := copy(p, rw.readBuf) nread := copy(p, rw.readBuf)
if nread > 0 { if nread > 0 {
log.Printf("Read %v bytes", nread) log.Printf("Read %v bytes", nread)
@ -39,15 +50,18 @@ func (rw *ChannelReadWriter) Read(p []byte) (n int, err error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case data, ok := <-rw.receiver: case data, ok := <-rw.receiver:
if !ok { if !ok {
return 0, io.EOF return 0, errors.New("ladida") //io.EOF
} }
nread = copy(p, data) nread = copy(p, data)
rw.readBuf = data[nread:] rw.readBuf = data[nread:]
return nread, nil return nread, nil
} }
} }
func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) { func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) {
rw.senderMutex.Lock()
defer rw.senderMutex.Unlock()
if rw.closed { if rw.closed {
return 0, errors.New("Write on closed channel") return 0, errors.New("Write on closed channel")
} }
@ -55,7 +69,7 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) {
// if context is canceled it should never write // if context is canceled it should never write
select { select {
case <-rw.ctx.Done(): case <-rw.ctx.Done():
rw.Close() //rw.Close()
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
default: default:
} }
@ -63,7 +77,7 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) {
select { select {
// deal with closing duirng the write // deal with closing duirng the write
case <-rw.ctx.Done(): case <-rw.ctx.Done():
rw.Close() //rw.Close()
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
case rw.sender <- p: case rw.sender <- p:
} }
@ -71,7 +85,14 @@ func (rw *ChannelReadWriter) Write(p []byte) (n int, err error) {
} }
func (rw *ChannelReadWriter) Close() error { func (rw *ChannelReadWriter) Close() error {
rw.senderMutex.Lock()
defer rw.senderMutex.Unlock()
if !rw.closed { if !rw.closed {
log.Println("Closing ChannelReadWriter")
buf := make([]byte, 0)
runtime.Stack(buf, false)
log.Printf("Stack %v", string(buf))
close(rw.sender) close(rw.sender)
rw.closed = true rw.closed = true
} }

View File

@ -33,22 +33,18 @@ func (s *ChannelReadWriterTestSuite) createChannel() {
fromChannel := make(chan []byte, 10) fromChannel := make(chan []byte, 10)
s.toChannel = toChannel s.toChannel = toChannel
s.fromChannel = fromChannel s.fromChannel = fromChannel
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := CreateTestContext(context.Background(), 10*time.Second)
ctx, timeoutCancelFunc := context.WithTimeout(ctx, 10*time.Second)
s.ctx = ctx s.ctx = ctx
s.cancelFunc = func() { s.cancelFunc = cancelFunc
timeoutCancelFunc()
cancelFunc()
}
s.conn = NewChannelReadWriter(ctx, toChannel, fromChannel) s.conn = NewChannelReadWriter(ctx, toChannel, fromChannel)
} }
func (s *ChannelReadWriterTestSuite) SetupSuite() { func (s *ChannelReadWriterTestSuite) SetupSuite() {
s.pprofServer = startPprof("") s.pprofServer = StartPprof("")
} }
func (s *ChannelReadWriterTestSuite) TearDownSuite() { func (s *ChannelReadWriterTestSuite) TearDownSuite() {
stopPprof(s.ctx, s.pprofServer) StopPprof(s.ctx, s.pprofServer)
} }
func (s *ChannelReadWriterTestSuite) SetupTest() { func (s *ChannelReadWriterTestSuite) SetupTest() {
@ -127,7 +123,7 @@ func (s *ChannelReadWriterTestSuite) Test_SuccessfulReadWriterToChannel() {
func (s *ChannelReadWriterTestSuite) runSuccessfulReadWriteToChannel(test SuccessfulTest) func() { func (s *ChannelReadWriterTestSuite) runSuccessfulReadWriteToChannel(test SuccessfulTest) func() {
return func() { return func() {
runAndWait( RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
for _, d := range test.data { for _, d := range test.data {
@ -163,7 +159,7 @@ type SuccessfulTest struct {
func (s *ChannelReadWriterTestSuite) runSuccessfulChannelToReadWrite(test SuccessfulTest) func() { func (s *ChannelReadWriterTestSuite) runSuccessfulChannelToReadWrite(test SuccessfulTest) func() {
return func() { return func() {
runAndWait( RunAndWait(
&s.Suite, &s.Suite,
func() any { func() any {
for _, d := range test.data { for _, d := range test.data {

View File

@ -12,9 +12,9 @@ type InmemoryConnection struct {
func NewInmemoryConnection(ctx context.Context, addr string) *InmemoryConnection { func NewInmemoryConnection(ctx context.Context, addr string) *InmemoryConnection {
pipe := InmemoryConnection{ pipe := InmemoryConnection{
ctx: ctx, ctx: ctx,
// arbitrary unbuffered channel, unbuffered is more similar to TCP connections. // TODO: somehow does not work with unbuffered channel and yamux
frontToBack: make(chan []byte), frontToBack: make(chan []byte, 0),
backToFront: make(chan []byte), backToFront: make(chan []byte, 0),
addr: addr, addr: addr,
} }
return &pipe return &pipe

View File

@ -1 +1,45 @@
package testsupport package testsupport
import (
"context"
"github.com/stretchr/testify/suite"
"net/http"
"testing"
"time"
)
type InMemoryTestSuite struct {
suite.Suite
pprofServer *http.Server
ctx context.Context
cancelFunc context.CancelFunc
pipe *InmemoryConnection
}
func TestInMemoryConnectionTestSuite(t *testing.T) {
suite.Run(t, &InMemoryTestSuite{})
}
func (s *InMemoryTestSuite) createConnection() {
ctx, cancelFunc := CreateTestContext(context.Background(), 10*time.Second)
s.ctx = ctx
s.cancelFunc = cancelFunc
s.pipe = NewInmemoryConnection(ctx, "inmemory")
}
func (s *InMemoryTestSuite) SetupSuite() {
s.pprofServer = StartPprof("")
}
func (s *InMemoryTestSuite) TearDownSuite() {
StopPprof(s.ctx, s.pprofServer)
}
func (s *InMemoryTestSuite) SetupTest() {
s.createConnection()
}
func (s *InMemoryTestSuite) TearDownTest() {
s.cancelFunc()
}

View File

@ -8,11 +8,12 @@ import (
"net/http" "net/http"
_ "runtime/pprof" _ "runtime/pprof"
"sync" "sync"
"time"
) )
type TestFunction func() any type TestFunction func() any
func runAndWait(suite *suite.Suite, functions ...TestFunction) []any { func RunAndWait(suite *suite.Suite, functions ...TestFunction) []any {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(len(functions)) wg.Add(len(functions))
res := make([]any, len(functions)) res := make([]any, len(functions))
@ -28,7 +29,7 @@ func runAndWait(suite *suite.Suite, functions ...TestFunction) []any {
return res return res
} }
func startPprof(port string) *http.Server { func StartPprof(port string) *http.Server {
if port == "" { if port == "" {
port = ":9000" port = ":9000"
} }
@ -47,7 +48,7 @@ func startPprof(port string) *http.Server {
return &srv return &srv
} }
func stopPprof(ctx context.Context, server *http.Server) { func StopPprof(ctx context.Context, server *http.Server) {
err := server.Shutdown(ctx) err := server.Shutdown(ctx)
if err != nil { if err != nil {
log.Println("Error shutting down test pprof server") log.Println("Error shutting down test pprof server")
@ -55,3 +56,13 @@ func stopPprof(ctx context.Context, server *http.Server) {
} }
log.Println("Test pprof server stopped") log.Println("Test pprof server stopped")
} }
func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
ctx, cancelFunc := context.WithCancel(ctx)
ctx, timeoutCancelFunc := context.WithTimeout(ctx, timeout)
compositeCancelFunc := func() {
timeoutCancelFunc()
cancelFunc()
}
return ctx, compositeCancelFunc
}