diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 14378d7..9082cd4 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -28,6 +28,10 @@ import ( var hostPrivateKey []byte func SftpHandler(sess ssh.Session) { + uid := int(time.Now().UnixMilli()) + agent.Login(uid, sess) + defer agent.LogOut(uid) + debugStream := io.Discard serverOptions := []sftp.ServerOption{ sftp.WithDebug(debugStream), @@ -63,12 +67,10 @@ func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { } uid := int(time.Now().UnixMilli()) agent.Login(uid, s) - go func() { - io.Copy(process.Pipe(), s) // stdin - }() - io.Copy(s, process.Pipe()) // stdout - process.Wait() + iowrappers.SynchronizeStreams(process.Pipe(), s) agent.LogOut(uid) + process.Wait() + process.Wait() }) log.Println("starting ssh server, waiting for debug sessions") @@ -157,7 +159,7 @@ func main() { var service AgentService shells := []string{"bash", "sh", "ash", "ksh", "zsh", "fish", "tcsh", "csh"} if runtime.GOOS == "windows" { - shells = []string{"cmd", "powershell", "bash"} + shells = []string{"powershell", "bash"} } shell := "" diff --git a/pkg/agent/session.go b/pkg/agent/session.go index 8fbfb6b..e9819b8 100644 --- a/pkg/agent/session.go +++ b/pkg/agent/session.go @@ -126,11 +126,17 @@ func PrintMessage(sshSession ssh.Session, message string) { } func LogStatus() { - fmt := "%-20s %-20s" + fmt := "%-20s %-20s %-20s" log.Println() - log.Printf(fmt, "UID", "START_TIME") + log.Printf(fmt, "UID", "START_TIME", "TYPE") for uid, session := range state.sessions { - log.Printf(fmt, strconv.Itoa(uid), session.startTime.Format(time.DateTime)) + sessionType := session.sshSession.Subsystem() + if sessionType == "" { + sessionType = "ssh" + } + log.Printf(fmt, strconv.Itoa(uid), + session.startTime.Format(time.DateTime), + sessionType) } log.Println() } diff --git a/pkg/iowrappers/sync.go b/pkg/iowrappers/sync.go index 2f0437f..4f3af2c 100644 --- a/pkg/iowrappers/sync.go +++ b/pkg/iowrappers/sync.go @@ -27,5 +27,5 @@ func SynchronizeStreams(stream1, stream2 io.ReadWriter) { }() <-waitChannel - log.Println("Connection closed") + log.Println("SynchronizeStreams: Connection closed") } diff --git a/pkg/terminal/process.go b/pkg/terminal/process.go index a9f9a97..b15c979 100644 --- a/pkg/terminal/process.go +++ b/pkg/terminal/process.go @@ -7,6 +7,7 @@ import ( type Process interface { Pipe() io.ReadWriter + Kill() error Wait() error } diff --git a/pkg/terminal/pty_linux.go b/pkg/terminal/pty_linux.go index e933864..ef45dbd 100644 --- a/pkg/terminal/pty_linux.go +++ b/pkg/terminal/pty_linux.go @@ -44,6 +44,10 @@ type ptyProcess struct { func (p ptyProcess) Pipe() io.ReadWriter { return p.f } +func (p ptyProcess) Kill() error { + return p.cmd.Process.Kill() +} + func (p ptyProcess) Wait() error { return p.cmd.Wait() } diff --git a/pkg/terminal/pty_windows.go b/pkg/terminal/pty_windows.go index 8b534c2..5593ae0 100644 --- a/pkg/terminal/pty_windows.go +++ b/pkg/terminal/pty_windows.go @@ -21,7 +21,7 @@ var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, return nil, err } pid, _, err := cpty.Spawn( - "cmd.exe", + name, args, &syscall.ProcAttr{ Env: env, @@ -31,7 +31,7 @@ var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, cpty.Close() return nil, err } - fmt.Printf("New process with pid %d spawned\n", pid) + log.Printf("New process with pid %d spawned\n", pid) process, err := os.FindProcess(pid) if err != nil { cpty.Close() @@ -42,7 +42,7 @@ var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, for win := range winCh { err = cpty.Resize(uint16(win.Width), uint16(win.Height)) if err != nil { - log.Printf("Feiled to resize terminal to %d x %d", win.Width, win.Height) + log.Printf("Failed to resize terminal to %d x %d", win.Width, win.Height) } } }() @@ -61,13 +61,18 @@ func (proc ptyProcess) Read(p []byte) (n int, err error) { return proc.cpty.OutPipe().Read(p) } func (proc ptyProcess) Write(p []byte) (n int, err error) { - return proc.Write(p) + uintn, err := proc.cpty.Write(p) + return int(uintn), err } func (p ptyProcess) Pipe() io.ReadWriter { return p } +func (p ptyProcess) Kill() error { + return p.process.Kill() +} + func (p ptyProcess) Wait() error { defer p.cpty.Close() ps, err := p.process.Wait()