now using maps of Guid to Agent/Client in the state, working towards the definitive solution.
Using LinkedMap that preserves insertion order for the implementation and also added unit tests for that.
This commit is contained in:
parent
98f6b414de
commit
fd18a63360
3
Makefile
3
Makefile
@ -13,6 +13,9 @@ generate:
|
||||
vet: fmt
|
||||
go vet ./...
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
|
||||
build: generate vet
|
||||
mkdir -p bin
|
||||
go build -o bin ./cmd/...
|
||||
|
@ -15,6 +15,6 @@ func NewStateNotifier() *StateNotifier {
|
||||
}
|
||||
|
||||
func (notifier StateNotifier) Publish(state *models.State) {
|
||||
notifier.webNotificationChannel <- state.Copy()
|
||||
notifier.prometheusNotificationChannel <- state.Copy()
|
||||
notifier.webNotificationChannel <- state
|
||||
notifier.prometheusNotificationChannel <- state
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"git.wamblee.org/converge/pkg/models"
|
||||
"git.wamblee.org/converge/pkg/support/collections"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@ -14,20 +15,20 @@ const NAMESPACE = "converge"
|
||||
|
||||
// more efficient state representation for state
|
||||
type PrometheusState struct {
|
||||
agents map[models.AgentGuid]*models.Agent
|
||||
clients map[models.ClientGuid]*models.Client
|
||||
agents *collections.LinkedMap[models.AgentGuid, *models.Agent]
|
||||
clients *collections.LinkedMap[models.ClientGuid, *models.Client]
|
||||
}
|
||||
|
||||
func NewPrometheusState(state *models.State) *PrometheusState {
|
||||
res := PrometheusState{
|
||||
agents: make(map[models.AgentGuid]*models.Agent),
|
||||
clients: make(map[models.ClientGuid]*models.Client),
|
||||
agents: collections.NewLinkedMap[models.AgentGuid, *models.Agent](),
|
||||
clients: collections.NewLinkedMap[models.ClientGuid, *models.Client](),
|
||||
}
|
||||
for i, _ := range state.Agents {
|
||||
res.agents[state.Agents[i].Guid] = &state.Agents[i]
|
||||
for agent := range state.Agents.RangeValues() {
|
||||
res.agents.Put(agent.Guid, agent)
|
||||
}
|
||||
for i, _ := range state.Clients {
|
||||
res.clients[state.Clients[i].Guid] = &state.Clients[i]
|
||||
for client := range state.Clients.RangeValues() {
|
||||
res.clients.Put(client.Guid, client)
|
||||
}
|
||||
return &res
|
||||
}
|
||||
@ -35,7 +36,7 @@ func NewPrometheusState(state *models.State) *PrometheusState {
|
||||
var (
|
||||
// remember previous values of agent guids and clients so that we can increment
|
||||
// the cumulative counters.
|
||||
lastState *PrometheusState = NewPrometheusState(&models.State{})
|
||||
lastState *PrometheusState = NewPrometheusState(models.NewState())
|
||||
|
||||
cumulativeAgentCount = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
@ -149,7 +150,7 @@ func clientLabels(client *models.Client) prometheus.Labels {
|
||||
}
|
||||
|
||||
func agentActive(agent *models.Agent) {
|
||||
prevAgent, ok := lastState.agents[agent.Guid]
|
||||
prevAgent, ok := lastState.agents.Get(agent.Guid)
|
||||
if ok && *prevAgent != *agent {
|
||||
removeAgentInfoMetrics(prevAgent)
|
||||
}
|
||||
@ -164,7 +165,7 @@ func agentActive(agent *models.Agent) {
|
||||
}
|
||||
|
||||
func clientActive(client *models.Client) {
|
||||
prevClient, ok := lastState.clients[client.Guid]
|
||||
prevClient, ok := lastState.clients.Get(client.Guid)
|
||||
if ok && *prevClient != *client {
|
||||
removeClientInfoMetrics(prevClient)
|
||||
}
|
||||
@ -222,12 +223,12 @@ func updateMetrics(state *models.State) {
|
||||
}
|
||||
|
||||
func updateDurations() {
|
||||
for _, agent := range lastState.agents {
|
||||
for agent := range lastState.agents.RangeValues() {
|
||||
agentDuration.
|
||||
With(prometheus.Labels{"agent_guid": string(agent.Guid)}).
|
||||
Set(float64(time.Now().Sub(agent.StartTime).Seconds()))
|
||||
}
|
||||
for _, client := range lastState.clients {
|
||||
for client := range lastState.clients.RangeValues() {
|
||||
clientDuration.
|
||||
With(prometheus.Labels{"client_guid": string(client.Guid)}).
|
||||
Set(float64(time.Now().Sub(client.StartTime).Seconds()))
|
||||
@ -239,13 +240,13 @@ func updateMetricsImpl(state *PrometheusState) {
|
||||
agentGuids := make(map[models.AgentGuid]*models.Agent)
|
||||
clientGuids := make(map[models.ClientGuid]*models.Client)
|
||||
|
||||
agentCount.Set(float64(len(state.agents)))
|
||||
agentCount.Set(float64(state.agents.Len()))
|
||||
disconnectedAgents := make(map[models.AgentGuid]*models.Agent)
|
||||
for _, agent := range lastState.agents {
|
||||
for agent := range lastState.agents.RangeValues() {
|
||||
disconnectedAgents[agent.Guid] = agent
|
||||
}
|
||||
for _, agent := range state.agents {
|
||||
if lastState.agents[agent.Guid] == nil {
|
||||
for agent := range state.agents.RangeValues() {
|
||||
if lastState.agents.Contains(agent.Guid) {
|
||||
cumulativeAgentCount.Inc()
|
||||
}
|
||||
delete(disconnectedAgents, agent.Guid)
|
||||
@ -256,15 +257,15 @@ func updateMetricsImpl(state *PrometheusState) {
|
||||
removeAgentMetrics(agent)
|
||||
}
|
||||
|
||||
clientCount.Set(float64(len(state.clients)))
|
||||
clientCount.Set(float64(state.clients.Len()))
|
||||
|
||||
// with this app
|
||||
disconnectedClients := make(map[models.ClientGuid]*models.Client)
|
||||
for _, client := range lastState.clients {
|
||||
for client := range lastState.clients.RangeValues() {
|
||||
disconnectedClients[client.Guid] = client
|
||||
}
|
||||
for _, client := range state.clients {
|
||||
if lastState.clients[client.Guid] == nil {
|
||||
for client := range state.clients.RangeValues() {
|
||||
if lastState.clients.Contains(client.Guid) {
|
||||
cumulativeClientCount.Inc()
|
||||
}
|
||||
delete(disconnectedClients, client.Guid)
|
||||
|
@ -79,7 +79,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
state := models.State{}
|
||||
state := models.NewState()
|
||||
agent := models.Agent{
|
||||
Guid: models.AgentGuid(strconv.Itoa(rand.Int())),
|
||||
RemoteAddr: "10.220.1.3:3333",
|
||||
@ -94,7 +94,7 @@ func main() {
|
||||
},
|
||||
ExpiryTime: time.Now().In(japan).Add(10 * time.Minute),
|
||||
}
|
||||
state.Agents = append(state.Agents, agent)
|
||||
state.Agents.Put(agent.Guid, &agent)
|
||||
client := models.Client{
|
||||
Guid: models.ClientGuid(strconv.Itoa(rand.Int())),
|
||||
RemoteAddr: models.RemoteAddr("10.1.3.3"),
|
||||
@ -104,7 +104,7 @@ func main() {
|
||||
StartTime: time.Now().In(japan),
|
||||
SessionType: models.SessionType("sftp"),
|
||||
}
|
||||
state.Clients = append(state.Clients, client)
|
||||
return templates2.SessionsTab(&state, netherlands)
|
||||
state.Clients.Put(client.Guid, &client)
|
||||
return templates2.SessionsTab(state, netherlands)
|
||||
})
|
||||
}
|
||||
|
5
go.mod
5
go.mod
@ -12,6 +12,7 @@ require (
|
||||
github.com/hashicorp/yamux v0.1.1
|
||||
github.com/pkg/sftp v1.13.6
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/crypto v0.25.0
|
||||
golang.org/x/term v0.22.0
|
||||
)
|
||||
@ -21,10 +22,14 @@ require (
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.5.0 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
golang.org/x/sys v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
13
go.sum
13
go.sum
@ -10,6 +10,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0=
|
||||
github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@ -27,6 +28,10 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE
|
||||
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo=
|
||||
github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@ -39,12 +44,14 @@ github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSz
|
||||
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
@ -83,6 +90,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"git.wamblee.org/converge/pkg/comms"
|
||||
"git.wamblee.org/converge/pkg/support/collections"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -39,15 +40,13 @@ type Client struct {
|
||||
// Created by the server and used for updating the web client
|
||||
// and prometheus metrics.
|
||||
type State struct {
|
||||
Agents []Agent
|
||||
Clients []Client
|
||||
Agents *collections.LinkedMap[AgentGuid, *Agent]
|
||||
Clients *collections.LinkedMap[ClientGuid, *Client]
|
||||
}
|
||||
|
||||
func (state *State) Copy() *State {
|
||||
c := State{}
|
||||
c.Agents = make([]Agent, len(state.Agents))
|
||||
c.Clients = make([]Client, len(state.Clients))
|
||||
copy(c.Agents, state.Agents)
|
||||
copy(c.Clients, state.Clients)
|
||||
return &c
|
||||
func NewState() *State {
|
||||
return &State{
|
||||
Agents: collections.NewLinkedMap[AgentGuid, *Agent](),
|
||||
Clients: collections.NewLinkedMap[ClientGuid, *Client](),
|
||||
}
|
||||
}
|
||||
|
@ -81,16 +81,14 @@ func NewAdmin() *Admin {
|
||||
func (admin *Admin) CreateNotifification() *models.State {
|
||||
admin.mutex.Lock()
|
||||
defer admin.mutex.Unlock()
|
||||
state := models.State{}
|
||||
state.Agents = make([]models.Agent, 0, len(admin.agents))
|
||||
state.Clients = make([]models.Client, 0, len(admin.clients))
|
||||
state := models.NewState()
|
||||
for _, agent := range admin.agents {
|
||||
state.Agents = append(state.Agents, agent.Agent)
|
||||
state.Agents.Put(agent.Guid, &agent.Agent)
|
||||
}
|
||||
for _, client := range admin.clients {
|
||||
state.Clients = append(state.Clients, client.Client)
|
||||
state.Clients.Put(client.Guid, &client.Client)
|
||||
}
|
||||
return &state
|
||||
return state
|
||||
}
|
||||
|
||||
func (admin *Admin) getFreeId(publicId models.RendezVousId) (models.RendezVousId, error) {
|
||||
|
@ -134,7 +134,7 @@ func logStatusImpl(admin *models.State, notifier Notifier) {
|
||||
|
||||
lines = append(lines, fmt.Sprintf(format, "AGENT", "ACTIVE_SINCE", "EXPIRY_TIME",
|
||||
"USER", "HOST", "OS"))
|
||||
for _, agent := range admin.Agents {
|
||||
for agent := range admin.Agents.RangeValues() {
|
||||
lines = append(lines, fmt.Sprintf(format, agent.PublicId,
|
||||
agent.StartTime.Format(time.DateTime),
|
||||
agent.ExpiryTime.Format(time.DateTime),
|
||||
@ -145,7 +145,7 @@ func logStatusImpl(admin *models.State, notifier Notifier) {
|
||||
lines = append(lines, "")
|
||||
format = "%-10s %-20s %-20s %-20s %-20s"
|
||||
lines = append(lines, fmt.Sprintf(format, "CLIENT", "AGENT", "ACTIVE_SINCE", "REMOTE_ADDRESS", "SESSION_TYPE"))
|
||||
for _, client := range admin.Clients {
|
||||
for client := range admin.Clients.RangeValues() {
|
||||
lines = append(lines, fmt.Sprintf(format,
|
||||
client.ClientId,
|
||||
client.PublicId,
|
||||
|
@ -123,7 +123,7 @@ func (session *WebSession) WriteNotifications(location *time.Location, ctx conte
|
||||
log.Println("channel closed")
|
||||
return
|
||||
}
|
||||
if session.writeNotificationToClient(location, notification) {
|
||||
if !session.writeNotificationToClient(location, notification) {
|
||||
return
|
||||
}
|
||||
case <-timer.C:
|
||||
@ -140,9 +140,9 @@ func (session *WebSession) writeNotificationToClient(location *time.Location, no
|
||||
err := templates.State(notification, location).Render(context.Background(), session.conn)
|
||||
if err != nil {
|
||||
log.Printf("WS connection closed: %v", err)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (sessions *WebSessions) SessionClosed(session *WebSession) {
|
||||
|
@ -31,7 +31,7 @@ templ State(state *models.State, location *time.Location) {
|
||||
|
||||
<h3>agents</h3>
|
||||
|
||||
if len(state.Agents) == 0 {
|
||||
if state.Agents.Len() == 0 {
|
||||
<p>-</p>
|
||||
} else {
|
||||
<table class="table">
|
||||
@ -46,7 +46,7 @@ templ State(state *models.State, location *time.Location) {
|
||||
<th>shell</th>
|
||||
</tr>
|
||||
</thead>
|
||||
for _, agent := range state.Agents {
|
||||
for agent := range state.Agents.RangeValues() {
|
||||
<tr>
|
||||
<td>{string(agent.PublicId)}</td>
|
||||
<td>{agent.StartTime.In(location).Format(time.DateTime)}</td>
|
||||
@ -64,7 +64,7 @@ templ State(state *models.State, location *time.Location) {
|
||||
|
||||
<h3>clients</h3>
|
||||
|
||||
if len(state.Clients) == 0 {
|
||||
if state.Clients.Len() == 0 {
|
||||
<p>-</p>
|
||||
} else {
|
||||
<table class="table">
|
||||
@ -80,7 +80,7 @@ templ State(state *models.State, location *time.Location) {
|
||||
<th>shell</th>
|
||||
</tr>
|
||||
</thead>
|
||||
for _, client := range state.Clients {
|
||||
for client := range state.Clients.RangeValues() {
|
||||
<tr>
|
||||
<td>{string(client.ClientId)}</td>
|
||||
<td>{client.StartTime.In(location).Format(time.DateTime)}</td>
|
||||
|
159
pkg/support/collections/linkedmap.go
Normal file
159
pkg/support/collections/linkedmap.go
Normal file
@ -0,0 +1,159 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// similar to linkes hash map in Java, a map that preserves insertion order
|
||||
|
||||
var checkStatus = false
|
||||
|
||||
type Node[K comparable, V any] struct {
|
||||
key K
|
||||
value V
|
||||
prev *Node[K, V]
|
||||
next *Node[K, V]
|
||||
}
|
||||
|
||||
type LinkedMap[K comparable, V any] struct {
|
||||
first *Node[K, V]
|
||||
last *Node[K, V]
|
||||
collection map[K]*Node[K, V]
|
||||
}
|
||||
|
||||
func NewLinkedMap[K comparable, V any]() *LinkedMap[K, V] {
|
||||
res := LinkedMap[K, V]{
|
||||
first: nil,
|
||||
last: nil,
|
||||
collection: make(map[K]*Node[K, V]),
|
||||
}
|
||||
res.check()
|
||||
return &res
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) Len() int {
|
||||
return len(m.collection)
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) Put(key K, value V) {
|
||||
defer m.check()
|
||||
newNode := &Node[K, V]{
|
||||
key: key,
|
||||
value: value,
|
||||
prev: m.last,
|
||||
next: nil,
|
||||
}
|
||||
if m.first == nil {
|
||||
m.first = newNode
|
||||
m.last = m.first
|
||||
m.collection[key] = m.first
|
||||
return
|
||||
}
|
||||
m.Delete(key)
|
||||
m.last.next = newNode
|
||||
m.last = newNode
|
||||
m.collection[key] = newNode
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) Delete(key K) bool {
|
||||
defer m.check()
|
||||
node, ok := m.collection[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if node.prev != nil {
|
||||
node.prev.next = node.next
|
||||
} else {
|
||||
m.first = node.next
|
||||
}
|
||||
if node.next != nil {
|
||||
node.next.prev = node.prev
|
||||
} else {
|
||||
m.last = node.prev
|
||||
}
|
||||
delete(m.collection, key)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) Get(key K) (V, bool) {
|
||||
defer m.check()
|
||||
v, ok := m.collection[key]
|
||||
if !ok {
|
||||
return *new(V), false
|
||||
}
|
||||
return v.value, true
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) Contains(key K) bool {
|
||||
_, ok := m.collection[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
type Entry[K comparable, V any] struct {
|
||||
Key K
|
||||
Value V
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) RangeKeys() <-chan K {
|
||||
defer m.check()
|
||||
res := make(chan K, len(m.collection))
|
||||
for node := m.first; node != nil; node = node.next {
|
||||
res <- node.key
|
||||
}
|
||||
close(res)
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) RangeValues() <-chan V {
|
||||
defer m.check()
|
||||
res := make(chan V, len(m.collection))
|
||||
for node := m.first; node != nil; node = node.next {
|
||||
res <- node.value
|
||||
}
|
||||
close(res)
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) RangeEntries() <-chan Entry[K, V] {
|
||||
defer m.check()
|
||||
res := make(chan Entry[K, V], len(m.collection))
|
||||
for node := m.first; node != nil; node = node.next {
|
||||
res <- Entry[K, V]{
|
||||
Key: node.key,
|
||||
Value: node.value,
|
||||
}
|
||||
}
|
||||
close(res)
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *LinkedMap[K, V]) check() {
|
||||
if !checkStatus {
|
||||
return
|
||||
}
|
||||
assert := func(c bool, text string) {
|
||||
if !c {
|
||||
panic(text)
|
||||
}
|
||||
}
|
||||
if m.first == nil {
|
||||
assert(m.last == nil, "Last should be nil")
|
||||
}
|
||||
if m.first != nil {
|
||||
assert(m.last != nil, "Last must not be nil")
|
||||
}
|
||||
if m.first == nil {
|
||||
assert(m.Len() == 0, "Len should be 0")
|
||||
}
|
||||
if m.first != nil {
|
||||
assert(m.Len() > 0, "Len should be > 0")
|
||||
count := 1
|
||||
for node := m.first; node.next != nil && count < 1000; node = node.next {
|
||||
if node.prev != nil {
|
||||
assert(node.prev.next == node, "Broken link between nodes")
|
||||
}
|
||||
count++
|
||||
}
|
||||
assert(count == m.Len(), fmt.Sprintf("Len expected %d, got %d", count, m.Len()))
|
||||
}
|
||||
}
|
130
pkg/support/collections/linkedmap_test.go
Normal file
130
pkg/support/collections/linkedmap_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
checkStatus = true
|
||||
exitCode := m.Run()
|
||||
checkStatus = false
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func contentCheck(t *testing.T, m *LinkedMap[string, int],
|
||||
keys []string, values []int) {
|
||||
|
||||
assert.True(t, len(keys) == len(values), "input error expected keys and values differ in length")
|
||||
|
||||
// keys
|
||||
i := 0
|
||||
for key := range m.RangeKeys() {
|
||||
assert.True(t, i < len(keys), "Too many elements in map")
|
||||
assert.Equal(t, keys[i], key)
|
||||
i++
|
||||
}
|
||||
assert.Equal(t, len(keys), i)
|
||||
|
||||
// values
|
||||
i = 0
|
||||
for value := range m.RangeValues() {
|
||||
assert.True(t, i < len(values), "Too many elements in map")
|
||||
assert.Equal(t, values[i], value)
|
||||
i++
|
||||
}
|
||||
assert.Equal(t, len(values), i)
|
||||
|
||||
// Entries
|
||||
i = 0
|
||||
for entry := range m.RangeEntries() {
|
||||
assert.True(t, i < len(values), "Too many elements in map")
|
||||
assert.Equal(t, keys[i], entry.Key)
|
||||
assert.Equal(t, values[i], entry.Value)
|
||||
i++
|
||||
}
|
||||
assert.Equal(t, len(values), i)
|
||||
|
||||
// Get and Contains
|
||||
for i, key := range keys {
|
||||
v, ok := m.Get(key)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, values[i], v)
|
||||
assert.True(t, m.Contains(key))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_emptymap(t *testing.T) {
|
||||
m := NewLinkedMap[string, int]()
|
||||
contentCheck(t, m, []string{}, []int{})
|
||||
}
|
||||
|
||||
func Test_elementAddRemove(t *testing.T) {
|
||||
m := NewLinkedMap[string, int]()
|
||||
m.Put("a", 1)
|
||||
contentCheck(t, m, []string{"a"}, []int{1})
|
||||
|
||||
assert.False(t, m.Delete("b"))
|
||||
contentCheck(t, m, []string{"a"}, []int{1})
|
||||
|
||||
assert.True(t, m.Delete("a"))
|
||||
contentCheck(t, m, []string{}, []int{})
|
||||
}
|
||||
|
||||
func Test_GetContainsForElementsNotInMap(t *testing.T) {
|
||||
m := createSimpleMap(t)
|
||||
|
||||
assert.False(t, m.Contains("d"))
|
||||
val, ok := m.Get("d")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
|
||||
func Test_elementRemoveBeginning(t *testing.T) {
|
||||
m := createSimpleMap(t)
|
||||
|
||||
assert.True(t, m.Delete("a"))
|
||||
contentCheck(t, m, []string{"b", "c"}, []int{2, 3})
|
||||
}
|
||||
|
||||
func Test_elementRemoveMiddle(t *testing.T) {
|
||||
m := createSimpleMap(t)
|
||||
|
||||
assert.True(t, m.Delete("b"))
|
||||
contentCheck(t, m, []string{"a", "c"}, []int{1, 3})
|
||||
}
|
||||
|
||||
func Test_elementRemoveEnd(t *testing.T) {
|
||||
m := createSimpleMap(t)
|
||||
|
||||
assert.True(t, m.Delete("c"))
|
||||
contentCheck(t, m, []string{"a", "b"}, []int{1, 2})
|
||||
}
|
||||
|
||||
func Test_addSameElementAgain(t *testing.T) {
|
||||
m := createSimpleMap(t)
|
||||
|
||||
m.Put("b", 4)
|
||||
contentCheck(t, m, []string{"a", "c", "b"}, []int{1, 3, 4})
|
||||
}
|
||||
|
||||
func createSimpleMap(t *testing.T) *LinkedMap[string, int] {
|
||||
m := NewLinkedMap[string, int]()
|
||||
m.Put("a", 1)
|
||||
m.Put("b", 2)
|
||||
m.Put("c", 3)
|
||||
contentCheck(t, m, []string{"a", "b", "c"}, []int{1, 2, 3})
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_manyElements(t *testing.T) {
|
||||
m := NewLinkedMap[string, int]()
|
||||
chars := "0123456789"
|
||||
for i := 0; i < 10000; i++ {
|
||||
m.Put(chars[i%10:i%10+1], i)
|
||||
}
|
||||
contentCheck(t, m,
|
||||
[]string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"},
|
||||
[]int{9990, 9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999})
|
||||
}
|
Loading…
Reference in New Issue
Block a user