diff --git a/cmd/agent/agent.go b/cmd/agent/agent.go index 878845a..c9b6c7a 100755 --- a/cmd/agent/agent.go +++ b/cmd/agent/agent.go @@ -4,9 +4,9 @@ import ( "bufio" "converge/pkg/agent" "converge/pkg/iowrappers" + "converge/pkg/terminal" "converge/pkg/websocketutil" "fmt" - "github.com/creack/pty" "github.com/gliderlabs/ssh" "github.com/gorilla/websocket" "github.com/hashicorp/yamux" @@ -54,36 +54,20 @@ func passwordAuth(ctx ssh.Context, password string) bool { func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { ssh.Handle(func(s ssh.Session) { - - cmd := exec.Command(shellCommand) - ptyReq, winCh, isPty := s.Pty() - if isPty { - workingDirectory, _ := os.Getwd() - cmd.Env = append(os.Environ(), - fmt.Sprintf("TERM=%s", ptyReq.Term), - fmt.Sprintf("agentdir=%s", workingDirectory)) - f, err := pty.Start(cmd) - if err != nil { - panic(err) - } - go func() { - for win := range winCh { - setWinsize(f, win.Width, win.Height) - } - }() - uid := int(time.Now().UnixMilli()) - agent.Login(uid, s) - - go func() { - io.Copy(f, s) // stdin - }() - io.Copy(s, f) // stdout - cmd.Wait() - agent.LogOut(uid) - } else { - io.WriteString(s, "No PTY requested.\n") - s.Exit(1) + workingDirectory, _ := os.Getwd() + env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) + process, err := terminal.PtySpawner.Start(s, env, shellCommand) + if err != nil { + panic(err) } + uid := int(time.Now().UnixMilli()) + agent.Login(uid, s) + go func() { + io.Copy(process.Pipe(), s) // stdin + }() + io.Copy(s, process.Pipe()) // stdout + process.Wait() + agent.LogOut(uid) }) log.Println("starting ssh server, waiting for debug sessions") diff --git a/cmd/agent/open_process_windows.go b/cmd/agent/open_process_windows.go deleted file mode 100644 index 0378ba4..0000000 --- a/cmd/agent/open_process_windows.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "fmt" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -func example() { - // Create pipes for stdin and stdout - var stdInRead, stdInWrite, stdOutRead, stdOutWrite windows.Handle - sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), InheritHandle: 1} - - err := windows.CreatePipe(&stdInRead, &stdInWrite, sa, 0) - if err != nil { - fmt.Println("Error creating stdin pipe:", err) - return - } - defer windows.CloseHandle(stdInRead) - defer windows.CloseHandle(stdInWrite) - - err = windows.CreatePipe(&stdOutRead, &stdOutWrite, sa, 0) - if err != nil { - fmt.Println("Error creating stdout pipe:", err) - return - } - defer windows.CloseHandle(stdOutRead) - defer windows.CloseHandle(stdOutWrite) - - // Set the pipe to non-blocking mode - mode := uint32(windows.PIPE_NOWAIT) - err = windows.SetNamedPipeHandleState(stdInWrite, &mode, nil, nil) - if err != nil { - fmt.Println("Error setting stdin pipe to non-blocking:", err) - return - } - err = windows.SetNamedPipeHandleState(stdOutRead, &mode, nil, nil) - if err != nil { - fmt.Println("Error setting stdout pipe to non-blocking:", err) - return - } - - // Prepare process startup info - si := &windows.StartupInfo{ - Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})), - Flags: windows.STARTF_USESTDHANDLES, - StdInput: stdInRead, - StdOutput: stdOutWrite, - StdErr: stdOutWrite, - } - pi := &windows.ProcessInformation{} - - // Create the process - cmd := "cmd.exe" - err = windows.CreateProcess( - nil, - syscall.StringToUTF16Ptr(cmd), - nil, - nil, - true, - 0, - nil, - nil, - si, - pi, - ) - if err != nil { - fmt.Println("Error creating process:", err) - return - } - defer windows.CloseHandle(pi.Process) - defer windows.CloseHandle(pi.Thread) - - // Write to the process - message := "echo Hello, World!\r\n" - var written uint32 - err = windows.WriteFile(stdInWrite, []byte(message), &written, nil) - if err != nil { - fmt.Println("Error writing to process:", err) - return - } - - // Read from the process - buffer := make([]byte, 1024) - var read uint32 - err = windows.ReadFile(stdOutRead, buffer, &read, nil) - if err != nil && err != windows.ERROR_NO_DATA { - fmt.Println("Error reading from process:", err) - return - } - - fmt.Printf("Output: %s", buffer[:read]) - - // Wait for the process to finish - windows.WaitForSingleObject(pi.Process, windows.INFINITE) -} diff --git a/pkg/terminal/process.go b/pkg/terminal/process.go new file mode 100644 index 0000000..a9f9a97 --- /dev/null +++ b/pkg/terminal/process.go @@ -0,0 +1,18 @@ +package terminal + +import ( + "github.com/gliderlabs/ssh" + "io" +) + +type Process interface { + Pipe() io.ReadWriter + Wait() error +} + +// A function definition used to start processes +type Spawner func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) + +func (s Spawner) Start(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) { + return s(sshSession, env, name, arg...) +} diff --git a/pkg/terminal/pty_linux.go b/pkg/terminal/pty_linux.go new file mode 100644 index 0000000..e933864 --- /dev/null +++ b/pkg/terminal/pty_linux.go @@ -0,0 +1,49 @@ +package terminal + +import ( + "fmt" + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + "io" + "os" + "os/exec" + "syscall" + "unsafe" +) + +var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) { + cmd := exec.Command(name, arg...) + ptyReq, winCh, isPty := sshSession.Pty() + if !isPty { + return nil, fmt.Errorf("ssh session is not a pty") + } + cmd.Env = append(env, + fmt.Sprintf("TERM=%s", ptyReq.Term)) + f, err := pty.Start(cmd) + if err != nil { + return nil, err + } + go func() { + for win := range winCh { + setWinsize(f, win.Width, win.Height) + } + }() + return ptyProcess{cmd: cmd, f: f}, nil +}) + +func setWinsize(f *os.File, w, h int) { + syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), + uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) +} + +type ptyProcess struct { + cmd *exec.Cmd + f *os.File +} + +func (p ptyProcess) Pipe() io.ReadWriter { + return p.f +} +func (p ptyProcess) Wait() error { + return p.cmd.Wait() +}