diff --git a/pkg/models/state.go b/pkg/models/state.go index 0b30a4d..8d0c269 100644 --- a/pkg/models/state.go +++ b/pkg/models/state.go @@ -66,3 +66,15 @@ func NewState() *State { Clients: collections.NewLinkedMap[ClientGuid, *Client](), } } + +// for copy on write +func (state *State) Copy() *State { + res := NewState() + for entry := range state.Agents.RangeEntries() { + res.Agents.Put(entry.Key, entry.Value) + } + for entry := range state.Clients.RangeEntries() { + res.Clients.Put(entry.Key, entry.Value) + } + return res +} diff --git a/pkg/server/admin/admin.go b/pkg/server/admin/admin.go index 3c3bd79..bef5de8 100644 --- a/pkg/server/admin/admin.go +++ b/pkg/server/admin/admin.go @@ -16,7 +16,7 @@ import ( ) type agentConnection struct { - models.Agent + Info *models.Agent // server session CommChannel comms.CommChannel } @@ -25,35 +25,37 @@ var agentIdGenerator = concurrency.NewAtomicCounter() var clientIdGenerator = concurrency.NewAtomicCounter() type ClientConnection struct { - models.Client + Info *models.Client agentConnection net.Conn clientConnection iowrappers2.ReadWriteAddrCloser } func newAgent(commChannel comms.CommChannel, publicId models.RendezVousId, agentInfo comms.EnvironmentInfo) *agentConnection { + agent := models.Agent{ + Guid: models.AgentGuid(strconv.Itoa(rand.Int())), + RemoteAddr: models.RemoteAddr(commChannel.Session.RemoteAddr().String()), + PublicId: publicId, + StartTime: time.Now(), + EnvironmentInfo: agentInfo, + } return &agentConnection{ - Agent: models.Agent{ - Guid: models.AgentGuid(strconv.Itoa(rand.Int())), - RemoteAddr: models.RemoteAddr(commChannel.Session.RemoteAddr().String()), - PublicId: publicId, - StartTime: time.Now(), - EnvironmentInfo: agentInfo, - }, + Info: &agent, CommChannel: commChannel, } } func newClient(publicId models.RendezVousId, clientConn iowrappers2.ReadWriteAddrCloser, agentConn net.Conn, agentGuid models.AgentGuid) *ClientConnection { + client := models.Client{ + Guid: models.ClientGuid(strconv.Itoa(rand.Int())), + RemoteAddr: models.RemoteAddr(clientConn.RemoteAddr().String()), + PublicId: publicId, + AgentGuid: agentGuid, + ClientId: models.ClientId(strconv.Itoa(clientIdGenerator.IncrementAndGet())), + StartTime: time.Now(), + } return &ClientConnection{ - Client: models.Client{ - Guid: models.ClientGuid(strconv.Itoa(rand.Int())), - RemoteAddr: models.RemoteAddr(clientConn.RemoteAddr().String()), - PublicId: publicId, - AgentGuid: agentGuid, - ClientId: models.ClientId(strconv.Itoa(clientIdGenerator.IncrementAndGet())), - StartTime: time.Now(), - }, + Info: &client, agentConnection: agentConn, clientConnection: clientConn, } @@ -65,7 +67,13 @@ func (match *ClientConnection) Synchronize() { type Admin struct { // map of public id to agent - mutex sync.Mutex + mutex sync.Mutex + // for reporting state to webclients and prometheus and also used for + // logging the state. This uses copy-on-write. Every time an agent or + // clinet is added or removed a copy is made. + state *models.State + + // TODO: use linked map for both of these agents map[models.RendezVousId]*agentConnection clients []*ClientConnection } @@ -73,6 +81,7 @@ type Admin struct { func NewAdmin() *Admin { return &Admin{ mutex: sync.Mutex{}, + state: models.NewState(), agents: make(map[models.RendezVousId]*agentConnection), clients: make([]*ClientConnection, 0), // not strictly needed } @@ -81,30 +90,17 @@ func NewAdmin() *Admin { func (admin *Admin) CreateNotifification() *models.State { admin.mutex.Lock() defer admin.mutex.Unlock() - state := models.NewState() - for _, agent := range admin.agents { - state.Agents.Put(agent.Guid, &agent.Agent) - } - for _, client := range admin.clients { - state.Clients.Put(client.Guid, &client.Client) - } - return state + return admin.state } func (admin *Admin) getFreeId(publicId models.RendezVousId) (models.RendezVousId, error) { - usedIds := make(map[models.RendezVousId]bool) - for _, agent := range admin.agents { - usedIds[agent.PublicId] = true - } - if !usedIds[publicId] { + if admin.agents[publicId] == nil { return publicId, nil } - if usedIds[publicId] { - for i := 0; i < 100; i++ { - candidate := string(publicId) + "-" + strconv.Itoa(i) - if !usedIds[models.RendezVousId(candidate)] { - return models.RendezVousId(candidate), nil - } + for i := 0; i < 100; i++ { + candidate := models.RendezVousId(string(publicId) + "-" + strconv.Itoa(i)) + if admin.agents[candidate] == nil { + return candidate, nil } } return "", fmt.Errorf("Could not allocate agent id based on requested public id '%s'", publicId) @@ -132,8 +128,8 @@ func (admin *Admin) AddAgent(publicId models.RendezVousId, agentInfo comms.Envir Message: err.Error(), }) } - agent := admin.agents[publicId] - if agent != nil { + agentCheck := admin.agents[publicId] + if agentCheck != nil { return nil, fmt.Errorf("SHOULD NEVER GET HERE!!!, A different agent with same PublicId '%s' already registered", publicId) } @@ -142,7 +138,10 @@ func (admin *Admin) AddAgent(publicId models.RendezVousId, agentInfo comms.Envir if err != nil { return nil, err } - agent = newAgent(commChannel, publicId, agentInfo) + agent := newAgent(commChannel, publicId, agentInfo) + + admin.state = admin.state.Copy() + admin.state.Agents.Put(agent.Info.Guid, agent.Info) admin.agents[publicId] = agent return agent, nil } @@ -165,17 +164,19 @@ func (admin *Admin) AddClient(publicId models.RendezVousId, clientConn iowrapper log.Println("Sending connection information to agent") - client := newClient(publicId, clientConn, agentConn, agent.Guid) + client := newClient(publicId, clientConn, agentConn, agent.Info.Guid) // Before using this connection for SSH we use it to send client metadata to the // agent err = comms.SendClientInfo(agentConn, comms.ClientInfo{ - ClientId: string(client.ClientId), + ClientId: string(client.Info.ClientId), }) if err != nil { return nil, err } + admin.state = admin.state.Copy() + admin.state.Clients.Put(client.Info.Guid, client.Info) admin.clients = append(admin.clients, client) return client, nil } @@ -205,6 +206,8 @@ func (admin *Admin) RemoveAgent(publicId models.RendezVousId) error { if err != nil { log.Printf("Could not close yamux client session for '%s'\n", publicId) } + admin.state = admin.state.Copy() + admin.state.Agents.Delete(agent.Info.Guid) delete(admin.agents, publicId) return nil } @@ -213,14 +216,16 @@ func (admin *Admin) RemoveClient(client *ClientConnection) error { admin.mutex.Lock() defer admin.mutex.Unlock() - log.Printf("Removing client: '%s' created at %s\n", client.ClientId, - client.StartTime.Format(time.DateTime)) + log.Printf("Removing client: '%s' created at %s\n", client.Info.Guid, + client.Info.StartTime.Format(time.DateTime)) // try to explicitly close connection to the agent. _ = client.agentConnection.Close() _ = client.clientConnection.Close() for i, _client := range admin.clients { - if _client.ClientId == client.ClientId { + if _client.Info.ClientId == client.Info.ClientId { + admin.state = admin.state.Copy() + admin.state.Clients.Delete(client.Info.Guid) admin.clients = append(admin.clients[:i], admin.clients[i+1:]...) break } @@ -231,7 +236,7 @@ func (admin *Admin) RemoveClient(client *ClientConnection) error { func (admin *Admin) SetSessionType(clientId models.ClientId, sessionType models.SessionType) { admin.mutex.Lock() defer admin.mutex.Unlock() - for _, client := range admin.clients { + for client := range admin.state.Clients.RangeValues() { if client.ClientId == clientId { client.SessionType = sessionType break diff --git a/pkg/server/matchmaker/matchmaker.go b/pkg/server/matchmaker/matchmaker.go index 75f9044..ba8c292 100644 --- a/pkg/server/matchmaker/matchmaker.go +++ b/pkg/server/matchmaker/matchmaker.go @@ -39,7 +39,7 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW if err != nil { return err } - publicId = agent.PublicId + publicId = agent.Info.PublicId defer func() { converge.admin.RemoveAgent(publicId) converge.logStatus() @@ -48,7 +48,7 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW go func() { comms.ListenForAgentEvents(agent.CommChannel.SideChannel, func(info comms.EnvironmentInfo) { - agent.EnvironmentInfo = info + agent.Info.EnvironmentInfo = info converge.logStatus() }, func(session comms.SessionInfo) { @@ -56,7 +56,7 @@ func (converge *MatchMaker) Register(publicId models.RendezVousId, conn io.ReadW converge.admin.SetSessionType(models.ClientId(session.ClientId), models.SessionType(session.SessionType)) }, func(expiry comms.ExpiryTimeUpdate) { - agent.SetExpiryTime(expiry.ExpiryTime) + agent.Info.SetExpiryTime(expiry.ExpiryTime) converge.logStatus() }) }() @@ -114,7 +114,7 @@ func (converge *MatchMaker) Connect(wsProxyMode bool, publicId models.RendezVous if err != nil { return fmt.Errorf("Error receiving environment info from client: %v", err) } - client.EnvironmentInfo = clientEnvironment + client.Info.EnvironmentInfo = clientEnvironment } converge.logStatus() client.Synchronize()