diff --git a/pkg/agent/session.go b/pkg/agent/session.go index ef2fd5c..600b61d 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -49,7 +49,8 @@ type AgentState struct { // map of unique session id to a session sessions map[int]*AgentSession - agentUsed bool + lastUserLoginTime time.Time + agentUsed bool } type AgentSession struct { @@ -88,9 +89,14 @@ func ConfigureAgent(advanceWarningTime, agentExpiryTime, tickerInterval time.Dur tickerInterval: tickerInterval, ticker: time.NewTicker(tickerInterval), sessions: make(map[int]*AgentSession), - agentUsed: false, + + lastUserLoginTime: time.Time{}, + agentUsed: false, } + log.Printf("Agent expires at %s", + state.expiryTime(holdFilename).Format(time.DateTime)) + go func() { for { <-state.ticker.C @@ -184,6 +190,7 @@ func login(sessionId int, sshSession ssh.Session) { sshSession: sshSession, } state.sessions[sessionId] = &agentSession + state.lastUserLoginTime = time.Now() state.agentUsed = true logStatus() @@ -254,11 +261,19 @@ func fileExists(filename string) bool { } func (state *AgentState) expiryTime(filename string) time.Time { - stats, err := os.Stat(filename) - if err != nil { + if !state.agentUsed { return state.startTime.Add(state.agentExpriryTime) } - return stats.ModTime().Add(state.agentExpriryTime) + expiryTime := time.Time{} + stats, err := os.Stat(filename) + if err == nil { + expiryTime = stats.ModTime().Add(state.agentExpriryTime) + } + userLoginBaseExpiryTime := state.lastUserLoginTime.Add(state.agentExpriryTime) + if userLoginBaseExpiryTime.After(expiryTime) { + expiryTime = userLoginBaseExpiryTime + } + return expiryTime } func holdFileChange() {