diff --git a/cmd/converge/converge.go b/cmd/converge/converge.go index 8a592a4..9000096 100644 --- a/cmd/converge/converge.go +++ b/cmd/converge/converge.go @@ -1,6 +1,7 @@ package main import ( + "context" "converge/pkg/server/converge" "converge/pkg/support/websocketutil" "fmt" @@ -134,13 +135,14 @@ func main() { // for the web browser getting live status updates. sessionService := websocketutil.WebSocketService{ Handler: func(w http.ResponseWriter, r *http.Request, conn net.Conn) { - websession := websessions.NewSession(conn) + ctx, cancel := context.WithCancel(context.Background()) + websession := websessions.NewSession(conn, ctx) defer websessions.SessionClosed(websession) location, err := converge.GetUserLocation(r) if err != nil { panic(err) } - websession.WriteNotifications(location) + websession.WriteNotifications(location, cancel) }, Text: true, } diff --git a/pkg/server/converge/websessions.go b/pkg/server/converge/websessions.go index a868204..504bafb 100644 --- a/pkg/server/converge/websessions.go +++ b/pkg/server/converge/websessions.go @@ -20,6 +20,7 @@ type WebSessions struct { type WebSession struct { notifications chan *models.State conn net.Conn + ctx context.Context } func NewWebSessions(notifications chan *models.State) *WebSessions { @@ -41,16 +42,22 @@ func (sessions *WebSessions) notifyWebSessions(notification *models.State) { defer sessions.mutex.Unlock() sessions.lastNotification = notification for session, _ := range sessions.sessions { - session.notifications <- notification + select { + case <-session.ctx.Done(): + // session is closed, will be removed at higher level when session is done. + case session.notifications <- notification: + // Sent notification + } } } -func (sessions *WebSessions) NewSession(wsConnection net.Conn) *WebSession { +func (sessions *WebSessions) NewSession(wsConnection net.Conn, ctx context.Context) *WebSession { sessions.mutex.Lock() defer sessions.mutex.Unlock() session := &WebSession{ notifications: make(chan *models.State, 10), conn: wsConnection, + ctx: ctx, } sessions.sessions[session] = true sessions.logSessions() @@ -70,9 +77,11 @@ func GetUserLocation(r *http.Request) (*time.Location, error) { return time.LoadLocation(tzName) } -func (session *WebSession) WriteNotifications(location *time.Location) { +func (session *WebSession) WriteNotifications(location *time.Location, cancel context.CancelFunc) { timer := time.NewTicker(10 * time.Second) defer timer.Stop() + // if for some reason we cannot send notifications to the web client then the context is canceled. + defer cancel() for { select { case notification, ok := <-session.notifications: