package testsupport

import (
	"bytes"
	"context"
	"fmt"
	"git.wamblee.org/converge/pkg/support/ioutils"
	"git.wamblee.org/converge/pkg/support/pprof"
	"github.com/stretchr/testify/suite"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	_ "runtime/pprof"
	"sync"
	"text/template"
	"time"
)

// TestFunction is a unit of work run concurrently by RunAndWait. Its return
// value is stored in RunAndWait's result slice.
type TestFunction func() any

// RunAndWait runs each function in its own goroutine, blocks until all of
// them have returned, and returns their results in the same order as the
// functions were given.
//
// The suite argument is currently not used by this function.
func RunAndWait(s *suite.Suite, functions ...TestFunction) []any {
	var wg sync.WaitGroup
	wg.Add(len(functions))
	res := make([]any, len(functions))
	for i, function := range functions {
		// Pass the loop variables as arguments so every goroutine gets its own
		// copy even when built with Go < 1.22 (before per-iteration loop vars).
		go func(i int, function TestFunction) {
			defer wg.Done()
			res[i] = function()
		}(i, function)
	}
	wg.Wait()
	return res
}

// StartPprof starts an HTTP server whose handlers are registered by
// pprof.RegisterPprof under /debug/pprof. It only does so when the PPROF
// environment variable is non-empty; otherwise it starts nothing and
// returns nil.
//
// port is used as the server's listen address (http.Server.Addr) and
// defaults to ":9000" when empty. The server runs in a background goroutine;
// any serve error other than http.ErrServerClosed terminates the process via
// log.Fatalf. Stop the returned server with StopPprof.
func StartPprof(port string) *http.Server {
	if os.Getenv("PPROF") == "" {
		return nil
	}
	if port == "" {
		port = ":9000"
	}
	mux := http.NewServeMux()
	pprof.RegisterPprof(mux, "/debug/pprof")
	srv := &http.Server{
		Addr:    port,
		Handler: mux,
	}
	// Log before serving: ListenAndServe blocks until the server stops, so a
	// message placed after it would only ever appear at shutdown time.
	log.Println("Test pprof server starting: " + port)
	go func() {
		// ListenAndServe returns http.ErrServerClosed after Shutdown/Close;
		// anything else (e.g. the address is already in use) is fatal.
		if err := srv.ListenAndServe(); err != http.ErrServerClosed {
			log.Fatalf("Could not start pprof listener: %v", err)
		}
	}()
	return srv
}

// StopPprof gracefully shuts down a server previously returned by StartPprof,
// using ctx to bound the shutdown. It is a no-op when the PPROF environment
// variable is empty or when server is nil (for example when StartPprof did
// not start anything). Shutdown failures are logged, not returned.
func StopPprof(ctx context.Context, server *http.Server) {
	if os.Getenv("PPROF") == "" {
		return
	}
	if server == nil {
		// Calling Shutdown on a nil *http.Server would panic.
		return
	}
	if err := server.Shutdown(ctx); err != nil {
		log.Printf("Error shutting down test pprof server: %v", err)
		return
	}
	log.Println("Test pprof server stopped")
}

// CreateTestContext derives a context from ctx that is done when the returned
// CancelFunc is called (Err: context.Canceled), when ctx itself is done, or
// once timeout has elapsed (Err: context.DeadlineExceeded), whichever comes
// first. Callers should always call the returned CancelFunc to release the
// context's timer.
func CreateTestContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	// The CancelFunc returned by WithTimeout already cancels the context with
	// context.Canceled, so wrapping it in an extra WithCancel layer adds nothing.
	return context.WithTimeout(ctx, timeout)
}

// AssertWriteData writes data to writer with a single Write call and asserts,
// through s, that the call returned no error and reported every byte written.
func AssertWriteData(s *suite.Suite, data string, writer io.Writer) {
	n, err := writer.Write([]byte(data))
	// NoError prints the error message on failure; Nil only dumps the value.
	s.NoError(err)
	s.Equal(len(data), n)
}

// AssertReadData reads from reader until at least len(data) bytes have
// arrived and asserts, through s, that exactly data was received.
//
// Stream readers may return a message in several smaller chunks, so a single
// Read is not enough; io.ReadAtLeast keeps reading until the expected length
// is reached. It also drops an io.EOF that arrives together with the final
// bytes.
func AssertReadData(s *suite.Suite, data string, reader io.Reader) {
	// Extra capacity beyond len(data) lets a Read deliver surplus bytes, which
	// the length assertion below then reports as unexpected extra data.
	buf := make([]byte, len(data)+1024)
	n, err := io.ReadAtLeast(reader, buf, len(data))
	s.NoError(err)
	s.Equal(len(data), n)
	s.Equal(data, string(buf[:n]))
}

// PrintStackTraces logs the stack traces of all goroutines, framed by a
// "STACKTRACE" header and blank lines. It is useful for diagnosing hung tests.
func PrintStackTraces() {
	buf := make([]byte, 100000)
	for {
		n := runtime.Stack(buf, true)
		if n < len(buf) {
			// Keep only the bytes actually written; the rest of the buffer is
			// zero-filled and would otherwise be logged as NUL characters.
			buf = buf[:n]
			break
		}
		// runtime.Stack truncates silently when the buffer is full; retry
		// with a larger buffer so the complete dump is printed.
		buf = make([]byte, 2*len(buf))
	}
	log.Println("STACKTRACE")
	log.Println("")
	log.Println(string(buf))
	log.Println("")
}

// BidirectionalConnectionCheck asserts that data can travel both ways between
// two connected endpoints. Concurrently, the client side writes msg+" -> " to
// clientToServerRW and then expects msg+" <- " back, while the agent side
// reads msg+" -> " from agentToServerYamux and then writes msg+" <- ".
// It returns once both sides have finished.
func BidirectionalConnectionCheck(s *suite.Suite, msg string, clientToServerRW io.ReadWriteCloser, agentToServerYamux io.ReadWriter) {
	clientToAgent := msg + " -> "
	agentToClient := msg + " <- "

	clientSide := func() any {
		AssertWriteData(s, clientToAgent, clientToServerRW)
		AssertReadData(s, agentToClient, clientToServerRW)
		return nil
	}
	agentSide := func() any {
		AssertReadData(s, clientToAgent, agentToServerYamux)
		AssertWriteData(s, agentToClient, agentToServerYamux)
		return nil
	}

	RunAndWait(s, clientSide, agentSide)
}

// CheckCondition polls condition roughly every millisecond until it returns
// true (result: true) or ctx is done (result: false). The condition is always
// evaluated at least once, even if ctx is already done.
//
// Returning a bool instead of asserting here forces the check to be done in
// the test code, which leads to clearer error messages.
func CheckCondition(ctx context.Context, condition func() bool) bool {
	for {
		if condition() {
			return true
		}
		// Err is non-nil exactly when Done has been closed.
		if ctx.Err() != nil {
			return false
		}
		time.Sleep(time.Millisecond)
	}
}

// Template renders templateString as a text/template, using data as the
// template's root value, and returns the rendered text. It panics with the
// underlying error if the template fails to parse or to execute.
func Template(templateString string, data any) string {
	tmpl := template.Must(template.New("dummy").Parse(templateString))
	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, data); err != nil {
		panic(err)
	}
	return rendered.String()
}

// GetGoModDIr returns the first directory, starting at the current working
// directory and walking up towards the filesystem root, for which
// ioutils.FileExists reports a "go.mod" file. Because the walk stops at the
// first match, this is the nearest enclosing module, not necessarily the
// outermost one.
//
// It panics if the working directory cannot be determined or if the root is
// reached without finding go.mod.
func GetGoModDIr() string {
	dir, err := os.Getwd()
	if err != nil {
		panic(fmt.Sprintf("Unable to get the current directory %v", err))
	}
	for !ioutils.FileExists(filepath.Join(dir, "go.mod")) {
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir of the root is the root itself: nothing left to search.
			panic("Could not find top-level directory of converge module")
		}
		dir = parent
	}
	return dir
}