// converge/pkg/testsupport/utils.go
//
// 164 lines
// 3.6 KiB
// Go

package testsupport
import (
"bytes"
"context"
"fmt"
"git.wamblee.org/converge/pkg/support/ioutils"
"git.wamblee.org/converge/pkg/support/pprof"
"github.com/stretchr/testify/suite"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
_ "runtime/pprof"
"sync"
"text/template"
"time"
)
// TestFunction is a unit of work executed by RunAndWait. Its return value is
// collected into the result slice returned by RunAndWait.
type TestFunction func() any

// RunAndWait runs every function concurrently, each in its own goroutine, and
// blocks until all of them have returned. The result of functions[i] is stored
// at index i of the returned slice, so results keep argument order regardless
// of completion order.
//
// The suite parameter is currently unused; it is kept for API compatibility.
// A panic inside any function is not recovered and terminates the test binary.
func RunAndWait(suite *suite.Suite, functions ...TestFunction) []any {
	var wg sync.WaitGroup
	wg.Add(len(functions))
	res := make([]any, len(functions))
	for i, function := range functions {
		// Pass the loop variables explicitly so every goroutine gets its own
		// copy; with pre-Go 1.22 loop semantics all closures would share one
		// i/function and race on them.
		go func(i int, function TestFunction) {
			defer wg.Done()
			res[i] = function()
		}(i, function)
	}
	wg.Wait()
	return res
}
// StartPprof starts an HTTP server exposing the pprof handlers under
// /debug/pprof, but only when the PPROF environment variable is non-empty;
// otherwise it does nothing and returns nil.
//
// port is the listen address assigned to http.Server.Addr (e.g. ":9000"); an
// empty value defaults to ":9000". The server runs in a background goroutine
// and should be stopped with StopPprof. If serving ends with any error other
// than http.ErrServerClosed, the process exits via log.Fatalf.
func StartPprof(port string) *http.Server {
	if os.Getenv("PPROF") == "" {
		return nil
	}
	if port == "" {
		port = ":9000"
	}
	mux := http.NewServeMux()
	pprof.RegisterPprof(mux, "/debug/pprof")
	srv := &http.Server{
		Addr:    port,
		Handler: mux,
		// Bound header reads only; profile endpoints may legitimately stream
		// for a long time, so no Read/WriteTimeout is set.
		ReadHeaderTimeout: 10 * time.Second,
	}
	// Log before serving: ListenAndServe blocks until the server stops, so a
	// message placed after it would only ever appear on shutdown.
	log.Println("Test pprof server starting: " + port)
	go func() {
		if err := srv.ListenAndServe(); err != http.ErrServerClosed {
			log.Fatalf("Could not start pprof listener: %v", err)
		}
	}()
	return srv
}
// StopPprof gracefully shuts down a server previously returned by StartPprof.
// It does nothing when the PPROF environment variable is empty or when server
// is nil. ctx bounds how long Shutdown may wait for active connections to
// finish. Shutdown errors are logged, not returned.
func StopPprof(ctx context.Context, server *http.Server) {
	// Check server as well as PPROF: the environment may differ from when
	// StartPprof ran, and calling Shutdown on a nil server would panic.
	if os.Getenv("PPROF") == "" || server == nil {
		return
	}
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("Error shutting down test pprof server: %v", err)
		return
	}
	log.Println("Test pprof server stopped")
}
// CreateTestContext derives a context from ctx that is done when the returned
// CancelFunc is called, when timeout elapses, or when ctx itself is done,
// whichever happens first. Callers should always call the returned CancelFunc
// to release the associated timer.
func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	// context.WithTimeout already returns a CancelFunc that cancels the
	// derived context, so an extra WithCancel layer adds nothing.
	return context.WithTimeout(ctx, timeout)
}
// AssertWriteData writes data to writer with a single Write call and asserts,
// through s, that the call returned no error and reported writing every byte.
func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
	payload := []byte(data)
	written, err := writer.Write(payload)
	s.Nil(err)
	s.Equal(len(data), written)
}
// AssertReadData reads exactly len(data) bytes from reader and asserts,
// through s, that the read succeeded and that the bytes read equal data.
//
// io.ReadFull is used because a single Read on a stream (network connection,
// multiplexed stream, pipe) may legitimately return fewer bytes than the peer
// wrote, which would make the assertion flaky. Any bytes beyond len(data) are
// left unread in reader; for an empty data nothing is read.
func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
	buf := make([]byte, len(data))
	n, err := io.ReadFull(reader, buf)
	s.Nil(err)
	s.Equal(len(data), n)
	s.Equal(data, string(buf[:n]))
}
// PrintStackTraces logs the stack traces of all goroutines, framed by a
// "STACKTRACE" header and blank lines.
func PrintStackTraces() {
	// runtime.Stack truncates when buf is too small and returns the number of
	// bytes written. Grow the buffer until the whole dump fits, and keep only
	// the written bytes; the rest of the buffer is zero padding.
	buf := make([]byte, 100000)
	for {
		n := runtime.Stack(buf, true)
		if n < len(buf) {
			buf = buf[:n]
			break
		}
		buf = make([]byte, 2*len(buf))
	}
	log.Println("STACKTRACE")
	log.Println("")
	log.Println(string(buf))
	log.Println("")
}
// BidirectionalConnectionCheck verifies that data flows in both directions.
// Two sides run concurrently:
//   - the client side writes msg+" -> " to clientToServerRW and then expects
//     msg+" <- " back from it;
//   - the agent side reads msg+" -> " from agentToServerYamux and answers with
//     msg+" <- ".
//
// Assertion failures are reported through s. The function returns once both
// sides have finished.
func BidirectionalConnectionCheck(s *suite.Suite, msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) {
	outbound := msg + " -> "
	inbound := msg + " <- "

	clientSide := func() any {
		AssertWriteData(s, outbound, clientToServerRW)
		AssertReadData(s, inbound, clientToServerRW)
		return nil
	}
	agentSide := func() any {
		AssertReadData(s, outbound, agentToServerYamux)
		AssertWriteData(s, inbound, agentToServerYamux)
		return nil
	}

	RunAndWait(s, clientSide, agentSide)
}
// CheckCondition repeatedly evaluates condition, sleeping one millisecond
// between attempts, until it returns true or ctx is done. It returns true as
// soon as condition holds (even if ctx is already done at that moment). It
// returns false when ctx is done while condition is still false.
//
// Returning a bool instead of asserting here forces the check to be done in
// the test code, which leads to clearer failure messages.
func CheckCondition(ctx context.Context, condition func() bool) bool {
	for {
		if condition() {
			return true
		}
		if ctx.Err() != nil {
			return false
		}
		time.Sleep(time.Millisecond)
	}
}
// Template renders templateString (text/template syntax) with data and
// returns the output. It panics with the underlying error if the template
// fails to parse or to execute.
func Template(templateString string, data any) string {
	tmpl := template.Must(template.New("dummy").Parse(templateString))
	var out bytes.Buffer
	if err := tmpl.Execute(&out, data); err != nil {
		panic(err)
	}
	return out.String()
}
// GetGoModDIr starts at the current working directory and walks up towards
// the filesystem root. It returns the first directory for which
// ioutils.FileExists reports a go.mod file.
//
// It panics if the working directory cannot be determined, or if the root is
// reached without finding go.mod. The exported name's spelling is kept for
// existing callers.
func GetGoModDIr() string {
	dir, err := os.Getwd()
	if err != nil {
		panic(fmt.Sprintf("Unable to get the current directory %v", err))
	}
	for !ioutils.FileExists(filepath.Join(dir, "go.mod")) {
		parent := filepath.Dir(dir)
		// filepath.Dir of the root returns the root itself: stop there.
		if parent == dir {
			panic("Could not find top-level directory of converge module")
		}
		dir = parent
	}
	return dir
}