refactoring to support both windows and linux with totally different Pty code.

This commit is contained in:
Erik Brakkee 2024-07-22 23:05:51 +02:00
parent 2f40f86294
commit 1e422dd698
4 changed files with 81 additions and 128 deletions

View File

@ -4,9 +4,9 @@ import (
"bufio"
"converge/pkg/agent"
"converge/pkg/iowrappers"
"converge/pkg/terminal"
"converge/pkg/websocketutil"
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/gorilla/websocket"
"github.com/hashicorp/yamux"
@ -54,36 +54,20 @@ func passwordAuth(ctx ssh.Context, password string) bool {
func sshServer(hostKeyFile string, shellCommand string) *ssh.Server {
ssh.Handle(func(s ssh.Session) {
cmd := exec.Command(shellCommand)
ptyReq, winCh, isPty := s.Pty()
if isPty {
workingDirectory, _ := os.Getwd()
cmd.Env = append(os.Environ(),
fmt.Sprintf("TERM=%s", ptyReq.Term),
fmt.Sprintf("agentdir=%s", workingDirectory))
f, err := pty.Start(cmd)
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
process, err := terminal.PtySpawner.Start(s, env, shellCommand)
if err != nil {
panic(err)
}
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
}()
uid := int(time.Now().UnixMilli())
agent.Login(uid, s)
go func() {
io.Copy(f, s) // stdin
io.Copy(process.Pipe(), s) // stdin
}()
io.Copy(s, f) // stdout
cmd.Wait()
io.Copy(s, process.Pipe()) // stdout
process.Wait()
agent.LogOut(uid)
} else {
io.WriteString(s, "No PTY requested.\n")
s.Exit(1)
}
})
log.Println("starting ssh server, waiting for debug sessions")

View File

@ -1,98 +0,0 @@
package main
import (
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
func example() {
// Create pipes for stdin and stdout
var stdInRead, stdInWrite, stdOutRead, stdOutWrite windows.Handle
sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), InheritHandle: 1}
err := windows.CreatePipe(&stdInRead, &stdInWrite, sa, 0)
if err != nil {
fmt.Println("Error creating stdin pipe:", err)
return
}
defer windows.CloseHandle(stdInRead)
defer windows.CloseHandle(stdInWrite)
err = windows.CreatePipe(&stdOutRead, &stdOutWrite, sa, 0)
if err != nil {
fmt.Println("Error creating stdout pipe:", err)
return
}
defer windows.CloseHandle(stdOutRead)
defer windows.CloseHandle(stdOutWrite)
// Set the pipe to non-blocking mode
mode := uint32(windows.PIPE_NOWAIT)
err = windows.SetNamedPipeHandleState(stdInWrite, &mode, nil, nil)
if err != nil {
fmt.Println("Error setting stdin pipe to non-blocking:", err)
return
}
err = windows.SetNamedPipeHandleState(stdOutRead, &mode, nil, nil)
if err != nil {
fmt.Println("Error setting stdout pipe to non-blocking:", err)
return
}
// Prepare process startup info
si := &windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Flags: windows.STARTF_USESTDHANDLES,
StdInput: stdInRead,
StdOutput: stdOutWrite,
StdErr: stdOutWrite,
}
pi := &windows.ProcessInformation{}
// Create the process
cmd := "cmd.exe"
err = windows.CreateProcess(
nil,
syscall.StringToUTF16Ptr(cmd),
nil,
nil,
true,
0,
nil,
nil,
si,
pi,
)
if err != nil {
fmt.Println("Error creating process:", err)
return
}
defer windows.CloseHandle(pi.Process)
defer windows.CloseHandle(pi.Thread)
// Write to the process
message := "echo Hello, World!\r\n"
var written uint32
err = windows.WriteFile(stdInWrite, []byte(message), &written, nil)
if err != nil {
fmt.Println("Error writing to process:", err)
return
}
// Read from the process
buffer := make([]byte, 1024)
var read uint32
err = windows.ReadFile(stdOutRead, buffer, &read, nil)
if err != nil && err != windows.ERROR_NO_DATA {
fmt.Println("Error reading from process:", err)
return
}
fmt.Printf("Output: %s", buffer[:read])
// Wait for the process to finish
windows.WaitForSingleObject(pi.Process, windows.INFINITE)
}

18
pkg/terminal/process.go Normal file
View File

@ -0,0 +1,18 @@
package terminal
import (
"github.com/gliderlabs/ssh"
"io"
)
type Process interface {
Pipe() io.ReadWriter
Wait() error
}
// A function definition used to start processes
type Spawner func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error)
func (s Spawner) Start(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
return s(sshSession, env, name, arg...)
}

49
pkg/terminal/pty_linux.go Normal file
View File

@ -0,0 +1,49 @@
package terminal
import (
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"io"
"os"
"os/exec"
"syscall"
"unsafe"
)
var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
cmd := exec.Command(name, arg...)
ptyReq, winCh, isPty := sshSession.Pty()
if !isPty {
return nil, fmt.Errorf("ssh session is not a pty")
}
cmd.Env = append(env,
fmt.Sprintf("TERM=%s", ptyReq.Term))
f, err := pty.Start(cmd)
if err != nil {
return nil, err
}
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
}()
return ptyProcess{cmd: cmd, f: f}, nil
})
func setWinsize(f *os.File, w, h int) {
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
}
type ptyProcess struct {
cmd *exec.Cmd
f *os.File
}
func (p ptyProcess) Pipe() io.ReadWriter {
return p.f
}
func (p ptyProcess) Wait() error {
return p.cmd.Wait()
}