165 lines
3.6 KiB
Go
165 lines
3.6 KiB
Go
package testsupport
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"git.wamblee.org/converge/pkg/support/ioutils"
|
|
"git.wamblee.org/converge/pkg/support/pprof"
|
|
"github.com/stretchr/testify/suite"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
_ "runtime/pprof"
|
|
"sync"
|
|
"text/template"
|
|
"time"
|
|
)
|
|
|
|
type TestFunction func() any
|
|
|
|
func RunAndWait(suite *suite.Suite, functions ...TestFunction) []any {
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(len(functions))
|
|
res := make([]any, len(functions))
|
|
for i, function := range functions {
|
|
go func() {
|
|
defer wg.Done()
|
|
res[i] = function()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
return res
|
|
}
|
|
|
|
func StartPprof(port string) *http.Server {
|
|
if os.Getenv("PPROF") == "" {
|
|
return nil
|
|
}
|
|
if port == "" {
|
|
port = ":9000"
|
|
}
|
|
mux := http.NewServeMux()
|
|
pprof.RegisterPprof(mux, "/debug/pprof")
|
|
srv := http.Server{
|
|
Addr: port,
|
|
Handler: mux,
|
|
}
|
|
go func() {
|
|
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
|
log.Fatalf("Could not start pprof listener: %v", err)
|
|
}
|
|
log.Println("Test pprof server started: " + port)
|
|
}()
|
|
return &srv
|
|
}
|
|
|
|
// StopPprof gracefully shuts down a server returned by StartPprof, with ctx
// bounding how long the shutdown may take. Like StartPprof it does nothing
// when the PPROF environment variable is unset or empty. It also does nothing
// when server is nil, so it is safe to call with StartPprof's nil result.
// Shutdown errors are logged, not returned.
func StopPprof(ctx context.Context, server *http.Server) {
	if os.Getenv("PPROF") == "" || server == nil {
		return
	}
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("Error shutting down test pprof server: %v", err)
		return
	}
	log.Println("Test pprof server stopped")
}
|
|
|
|
// CreateTestContext derives a context from ctx that is done when the returned
// CancelFunc is called, when timeout elapses, or when ctx itself is done,
// whichever comes first. Callers should always call the CancelFunc to release
// the associated timer.
func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	// context.WithTimeout already supports both explicit cancellation and the
	// deadline, so no separate WithCancel layer or composite cancel is needed.
	return context.WithTimeout(ctx, timeout)
}
|
|
|
|
func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
|
|
n, err := writer.Write([]byte(data))
|
|
s.Nil(err)
|
|
s.Equal(len(data), n)
|
|
}
|
|
|
|
func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
|
|
buf := make([]byte, len(data)+1024)
|
|
n, err := reader.Read(buf)
|
|
s.Nil(err)
|
|
s.Equal(len(data), n)
|
|
s.Equal(data, string(buf[:n]))
|
|
}
|
|
|
|
// PrintStackTraces logs the stack traces of all goroutines, preceded by a
// "STACKTRACE" header line and surrounded by empty log lines.
func PrintStackTraces() {
	// runtime.Stack writes at most len(buf) bytes and returns the count
	// written; a full buffer means the dump may have been truncated, so grow
	// and retry. Only buf[:n] is printed, so the unused zero-filled tail of
	// the buffer never ends up in the log.
	buf := make([]byte, 100000)
	n := runtime.Stack(buf, true)
	for n == len(buf) {
		buf = make([]byte, 2*len(buf))
		n = runtime.Stack(buf, true)
	}
	log.Println("STACKTRACE")
	log.Println("")
	log.Println(string(buf[:n]))
	log.Println("")
}
|
|
|
|
func BidirectionalConnectionCheck(s *suite.Suite, msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) {
|
|
data1 := msg + " -> "
|
|
data2 := msg + " <- "
|
|
log.Printf("BIDIRECTIONAL CHECK %v", msg)
|
|
RunAndWait(
|
|
s,
|
|
func() any {
|
|
AssertWriteData(s, data1, clientToServerRW)
|
|
AssertReadData(s, data2, clientToServerRW)
|
|
return nil
|
|
},
|
|
func() any {
|
|
AssertReadData(s, data1, agentToServerYamux)
|
|
AssertWriteData(s, data2, agentToServerYamux)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CheckCondition polls condition roughly once per millisecond until it
// returns true (result true) or ctx is done (result false). The condition is
// always evaluated before ctx is consulted, so a condition that is already
// satisfied yields true even for a cancelled context.
//
// It deliberately returns a bool rather than asserting itself, so the
// assertion is made in the test code, which leads to clearer error messages.
func CheckCondition(ctx context.Context, condition func() bool) bool {
	for {
		if condition() {
			return true
		}
		if ctx.Err() != nil {
			return false
		}
		time.Sleep(time.Millisecond)
	}
}
|
|
|
|
// Template renders templateString (text/template syntax) against data and
// returns the resulting text. It returns no error: a parse or execution
// failure panics with the underlying error.
func Template(templateString string, data any) string {
	parsed := template.Must(template.New("dummy").Parse(templateString))

	var out bytes.Buffer
	if err := parsed.Execute(&out, data); err != nil {
		panic(err)
	}
	return out.String()
}
|
|
|
|
func GetGoModDIr() string {
|
|
curDir, err := os.Getwd()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Unable to get the current directory %v", err))
|
|
}
|
|
for mod := filepath.Join(curDir, "go.mod"); !ioutils.FileExists(mod); mod = filepath.Join(curDir, "go.mod") {
|
|
newCurDir := filepath.Dir(curDir)
|
|
if newCurDir == curDir {
|
|
panic("Could not find top-level directory of converge module")
|
|
}
|
|
curDir = newCurDir
|
|
}
|
|
|
|
return curDir
|
|
}
|