package prometheus

import (
	"git.wamblee.org/converge/pkg/models"
	"git.wamblee.org/converge/pkg/support/collections"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"log"
	"net/http"
	"time"
)

// NAMESPACE is the Prometheus namespace shared by every metric registered
// in this package, so metric names are exposed as converge_<name>.
const (
	NAMESPACE = "converge"
)

// PrometheusState is a snapshot of a models.State in which agents and clients
// are indexed by their GUID. updateMetricsImpl compares consecutive snapshots
// by GUID (Get/Contains) to find new, changed and disconnected agents/clients.
type PrometheusState struct {
	agents  *collections.LinkedMap[models.AgentGuid, *models.Agent]
	clients *collections.LinkedMap[models.ClientGuid, *models.Client]
}

// NewPrometheusState indexes the agents and clients of state by GUID.
// The agent and client pointers are stored as-is (not copied), so the
// returned snapshot shares its models with state. state must not be nil.
func NewPrometheusState(state *models.State) *PrometheusState {
	agentsByGuid := collections.NewLinkedMap[models.AgentGuid, *models.Agent]()
	for _, a := range state.Agents {
		agentsByGuid.Put(a.Guid, a)
	}

	clientsByGuid := collections.NewLinkedMap[models.ClientGuid, *models.Client]()
	for _, c := range state.Clients {
		clientsByGuid.Put(c.Guid, c)
	}

	return &PrometheusState{
		agents:  agentsByGuid,
		clients: clientsByGuid,
	}
}

var (
	// lastState is the snapshot processed by the previous updateMetricsImpl
	// call. It is compared with each new snapshot to increment the cumulative
	// counters and to find agents/clients that disappeared, whose series are
	// then deleted. updateDurations also uses it to decide which duration
	// gauges to refresh. It starts from models.NewState(), presumably an empty
	// state (verify). Within this file it is only accessed from tasks executed
	// by the prometheusChannel worker started in SetupPrometheus.
	lastState *PrometheusState = NewPrometheusState(models.NewState())

	// Cumulative counters (converge_agent_count_total and
	// converge_client_count_total), incremented by updateMetricsImpl.
	// Like every metric below, they are registered with the default
	// registry by promauto at package initialisation.
	cumulativeAgentCount = promauto.NewCounter(prometheus.CounterOpts{
		Namespace: NAMESPACE,
		Name:      "agent_count_total",
		Help:      "Total number of agents connected over time",
	})
	cumulativeClientCount = promauto.NewCounter(prometheus.CounterOpts{
		Namespace: NAMESPACE,
		Name:      "client_count_total",
		Help:      "Total number of clients connected over time",
	})

	// Number of agents/clients in the most recent snapshot.
	agentCount = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "agent_count",
		Help:      "Current number of agents",
	})
	clientCount = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "client_count",
		Help:      "Current number of clients",
	})

	// Per-GUID start time, set to StartTime as Unix milliseconds by
	// agentActive/clientActive and deleted on disconnect.
	agentStartTime = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "agent_start_time_millis",
		Help:      "Time the agent started",
	}, []string{"agent_guid"})
	clientStartTime = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "client_start_time_millis",
		Help:      "Time the client started",
	}, []string{"client_guid"})

	// Per-GUID seconds since StartTime. Refreshed on every notification and
	// once per second by updateDurations. Deleted 60 seconds after disconnect
	// so the final value can still be scraped.
	agentDuration = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "agent_duration_seconds",
		Help:      "Time the agent is already running",
	}, []string{"agent_guid"})
	clientDuration = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: NAMESPACE,
		Name:      "client_duration_seconds",
		Help:      "Time the client is already running",
	}, []string{"client_guid"})

	// agentInfo is an info-style gauge whose value is always 1. Its labels
	// describe the agent. The label names must match the keys produced by
	// agentLabels, which is used for both With and Delete.
	agentInfo = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: NAMESPACE,
			Name:      "agent_info",
			Help:      "A flexible gauge with dynamic labels, always set to 1",
		},
		[]string{
			"agent_guid",
			"agent_address",
			"agent_id",
			"agent_username",
			"agent_hostname",
			"agent_pwd",
			"agent_os",
			"agent_shell",
		})

	// clientInfo is an info-style gauge whose value is always 1. Its labels
	// describe the client and the agent it is connected to. The label names
	// must match the keys produced by clientLabels.
	clientInfo = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: NAMESPACE,
			Name:      "client_info",
			Help:      "A flexible gauge with dynamic labels, always set to 1",
		},
		[]string{"client_guid",
			"client_address",
			"client_id",
			"agent_id",
			"agent_guid",
			"client_sessiontype",
			"client_username",
			"client_hostname",
			"client_pwd",
			"client_os",
			"client_shell",
		}, // Label names
	)
)

// agentLabels builds the full label set of agent's agent_info series.
// The keys must stay in sync with the label names declared for agentInfo.
func agentLabels(agent *models.Agent) prometheus.Labels {
	labels := make(prometheus.Labels, 8)
	labels["agent_guid"] = string(agent.Guid)
	labels["agent_address"] = string(agent.RemoteAddr)
	labels["agent_id"] = string(agent.PublicId)
	labels["agent_username"] = agent.EnvironmentInfo.Username
	labels["agent_hostname"] = agent.EnvironmentInfo.Hostname
	labels["agent_pwd"] = agent.EnvironmentInfo.Pwd
	labels["agent_os"] = agent.EnvironmentInfo.OS
	labels["agent_shell"] = agent.EnvironmentInfo.Shell
	return labels
}

// clientLabels builds the full label set of client's client_info series,
// including the public id and GUID of the agent the client belongs to.
// The keys must stay in sync with the label names declared for clientInfo.
func clientLabels(client *models.Client) prometheus.Labels {
	labels := make(prometheus.Labels, 11)
	labels["client_guid"] = string(client.Guid)
	labels["client_address"] = string(client.RemoteAddr)
	labels["client_id"] = string(client.ClientId)
	labels["agent_id"] = string(client.PublicId)
	labels["agent_guid"] = string(client.AgentGuid)
	labels["client_sessiontype"] = string(client.SessionType)
	labels["client_username"] = client.EnvironmentInfo.Username
	labels["client_hostname"] = client.EnvironmentInfo.Hostname
	labels["client_pwd"] = client.EnvironmentInfo.Pwd
	labels["client_os"] = client.EnvironmentInfo.OS
	labels["client_shell"] = client.EnvironmentInfo.Shell
	return labels
}

// agentActive publishes the series for an agent that is present in the
// current snapshot: its info series, its start time (Unix milliseconds)
// and the number of seconds it has been running. If the agent was already
// in lastState but any of its fields changed, the previous info series is
// deleted first. Otherwise a series with outdated label values would remain.
func agentActive(agent *models.Agent) {
	if previous, seen := lastState.agents.Get(agent.Guid); seen && *previous != *agent {
		removeAgentInfoMetrics(previous)
	}
	agentInfo.With(agentLabels(agent)).Set(1)

	byGuid := prometheus.Labels{"agent_guid": string(agent.Guid)}
	agentStartTime.With(byGuid).Set(float64(agent.StartTime.UnixMilli()))
	agentDuration.With(byGuid).Set(time.Since(agent.StartTime).Seconds())
}

// clientActive publishes the series for a client that is present in the
// current snapshot: its info series, its start time (Unix milliseconds)
// and the number of seconds it has been running. If the client was already
// in lastState but any of its fields changed, the previous info series is
// deleted first. Otherwise a series with outdated label values would remain.
func clientActive(client *models.Client) {
	if previous, seen := lastState.clients.Get(client.Guid); seen && *previous != *client {
		removeClientInfoMetrics(previous)
	}
	clientInfo.With(clientLabels(client)).Set(1)

	byGuid := prometheus.Labels{"client_guid": string(client.Guid)}
	clientStartTime.With(byGuid).Set(float64(client.StartTime.UnixMilli()))
	clientDuration.With(byGuid).Set(time.Since(client.StartTime).Seconds())
}

// SetupPrometheus starts the background goroutines that keep the converge
// metrics up to date and registers the Prometheus scrape handler on mux
// under /metrics. The caller decides on which address/port mux is served.
//
// Three goroutines are started, none of which is ever stopped:
//   - a worker that executes every task sent on prometheusChannel, so all
//     metric updates (and all access to lastState) happen on one goroutine
//     without further locking;
//   - a ticker that schedules updateDurations once per second, so the
//     *_duration_seconds gauges keep advancing between notifications;
//   - a reader that turns every state received on notifications into a
//     metrics update. It returns when notifications is closed and ignores
//     nil states.
func SetupPrometheus(mux *http.ServeMux, notifications chan *models.State) {
	go func() {
		for task := range prometheusChannel {
			task()
		}
	}()

	go func() {
		ticker := time.NewTicker(time.Second)
		for range ticker.C {
			prometheusChannel <- updateDurations
		}
	}()

	go func() {
		// Ranging ends the loop once the channel is closed. A bare receive
		// would instead yield nil forever, and a nil state panics inside
		// NewPrometheusState on the worker goroutine.
		for state := range notifications {
			if state == nil {
				continue
			}
			updateMetrics(state)
		}
	}()

	mux.Handle("/metrics", promhttp.Handler())
}

// prometheusChannel carries metric-update tasks to the single worker
// goroutine started by SetupPrometheus. It is unbuffered, so a sender blocks
// until the worker accepts its task.
var prometheusChannel = make(chan func(), 0)

// serialize notifications and periodic updates of the durations.

// updateMetrics queues a metrics update for state on prometheusChannel and
// blocks until the worker goroutine accepts it. The conversion to a
// PrometheusState and the update itself both run on the worker.
func updateMetrics(state *models.State) {
	task := func() {
		snapshot := NewPrometheusState(state)
		updateMetricsImpl(snapshot)
	}
	prometheusChannel <- task
}

// updateDurations sets the duration gauge of every agent and client in
// lastState to the number of seconds since its StartTime. SetupPrometheus
// schedules it once per second through prometheusChannel.
func updateDurations() {
	for a := range lastState.agents.RangeValues() {
		elapsed := time.Since(a.StartTime)
		agentDuration.With(prometheus.Labels{"agent_guid": string(a.Guid)}).Set(elapsed.Seconds())
	}
	for c := range lastState.clients.RangeValues() {
		elapsed := time.Since(c.StartTime)
		clientDuration.With(prometheus.Labels{"client_guid": string(c.Guid)}).Set(elapsed.Seconds())
	}
}

// updateMetricsImpl brings all metrics in line with the snapshot state and
// then makes state the new lastState. Within this file it is only invoked
// from tasks run by the prometheusChannel worker, which keeps access to
// lastState on a single goroutine.
//
// The same steps are applied to agents and to clients:
//   - set the current-count gauge to the size of the snapshot;
//   - increment the cumulative counter once for every GUID that was not in
//     lastState, i.e. for every newly connected agent/client;
//   - refresh the per-GUID series of everything still connected;
//   - remove the series of everything in lastState that is now gone.
func updateMetricsImpl(state *PrometheusState) {
	agentCount.Set(float64(state.agents.Len()))
	disconnectedAgents := make(map[models.AgentGuid]*models.Agent, lastState.agents.Len())
	for agent := range lastState.agents.RangeValues() {
		disconnectedAgents[agent.Guid] = agent
	}
	for agent := range state.agents.RangeValues() {
		// Count only agents that were not present last time. Counting known
		// agents would inflate the total on every notification.
		if !lastState.agents.Contains(agent.Guid) {
			cumulativeAgentCount.Inc()
		}
		delete(disconnectedAgents, agent.Guid)
		agentActive(agent)
	}
	for _, agent := range disconnectedAgents {
		removeAgentMetrics(agent)
	}

	clientCount.Set(float64(state.clients.Len()))
	disconnectedClients := make(map[models.ClientGuid]*models.Client, lastState.clients.Len())
	for client := range lastState.clients.RangeValues() {
		disconnectedClients[client.Guid] = client
	}
	for client := range state.clients.RangeValues() {
		// Same as for agents: only newly connected clients are counted.
		if !lastState.clients.Contains(client.Guid) {
			cumulativeClientCount.Inc()
		}
		delete(disconnectedClients, client.Guid)
		clientActive(client)
	}
	for _, client := range disconnectedClients {
		removeClientMetrics(client)
	}

	lastState = state
}

// removeAgentInfoMetrics deletes the agent_info series whose labels are
// derived from agent. It reports whether such a series existed.
func removeAgentInfoMetrics(agent *models.Agent) bool {
	labels := agentLabels(agent)
	deleted := agentInfo.Delete(labels)
	return deleted
}

// removeAgentMetrics deletes all series of a disconnected agent. The info and
// start-time series are deleted immediately. The duration series is deleted
// 60 seconds later, so Prometheus can still scrape its final value. Any series
// that could not be deleted is logged.
func removeAgentMetrics(agent *models.Agent) {
	infoDeleted := removeAgentInfoMetrics(agent)
	byGuid := prometheus.Labels{"agent_guid": string(agent.Guid)}
	startDeleted := agentStartTime.Delete(byGuid)

	// The callback runs on its own goroutine once the delay has elapsed.
	time.AfterFunc(60*time.Second, func() {
		if !agentDuration.Delete(byGuid) {
			log.Printf("Could not delete duration timeseries for agent %s", agent.Guid)
		}
	})

	if infoDeleted && startDeleted {
		return
	}
	log.Printf("Could not delete all timeseries for agent %s (info %v, starttime %v) ",
		agent.Guid, infoDeleted, startDeleted)
}

// removeClientInfoMetrics deletes the client_info series whose labels are
// derived from client. It reports whether such a series existed.
func removeClientInfoMetrics(client *models.Client) bool {
	labels := clientLabels(client)
	deleted := clientInfo.Delete(labels)
	return deleted
}

// removeClientMetrics deletes all series of a disconnected client. The info
// and start-time series are deleted immediately. The duration series is
// deleted 60 seconds later, so Prometheus can still scrape its final value.
// Any series that could not be deleted is logged.
func removeClientMetrics(client *models.Client) {
	infoDeleted := removeClientInfoMetrics(client)
	byGuid := prometheus.Labels{"client_guid": string(client.Guid)}
	startDeleted := clientStartTime.Delete(byGuid)

	// The callback runs on its own goroutine once the delay has elapsed.
	time.AfterFunc(60*time.Second, func() {
		if !clientDuration.Delete(byGuid) {
			log.Printf("Could not delete duration timeseries for client %s", client.Guid)
		}
	})

	if infoDeleted && startDeleted {
		return
	}
	log.Printf("Could not delete all timeseries for client %s (info %v, starttime %v)", client.Guid, infoDeleted, startDeleted)
}