refactoring to support both windows and linux with totally different Pty code.
This commit is contained in:
parent
1ebee30c8c
commit
3bd0e3f3e1
@ -4,9 +4,9 @@ import (
|
||||
"bufio"
|
||||
"converge/pkg/agent"
|
||||
"converge/pkg/iowrappers"
|
||||
"converge/pkg/terminal"
|
||||
"converge/pkg/websocketutil"
|
||||
"fmt"
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/hashicorp/yamux"
|
||||
@ -54,36 +54,20 @@ func passwordAuth(ctx ssh.Context, password string) bool {
|
||||
|
||||
func sshServer(hostKeyFile string, shellCommand string) *ssh.Server {
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
|
||||
cmd := exec.Command(shellCommand)
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
if isPty {
|
||||
workingDirectory, _ := os.Getwd()
|
||||
cmd.Env = append(os.Environ(),
|
||||
fmt.Sprintf("TERM=%s", ptyReq.Term),
|
||||
fmt.Sprintf("agentdir=%s", workingDirectory))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinsize(f, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
uid := int(time.Now().UnixMilli())
|
||||
agent.Login(uid, s)
|
||||
|
||||
go func() {
|
||||
io.Copy(f, s) // stdin
|
||||
}()
|
||||
io.Copy(s, f) // stdout
|
||||
cmd.Wait()
|
||||
agent.LogOut(uid)
|
||||
} else {
|
||||
io.WriteString(s, "No PTY requested.\n")
|
||||
s.Exit(1)
|
||||
workingDirectory, _ := os.Getwd()
|
||||
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
|
||||
process, err := terminal.PtySpawner.Start(s, env, shellCommand)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
uid := int(time.Now().UnixMilli())
|
||||
agent.Login(uid, s)
|
||||
go func() {
|
||||
io.Copy(process.Pipe(), s) // stdin
|
||||
}()
|
||||
io.Copy(s, process.Pipe()) // stdout
|
||||
process.Wait()
|
||||
agent.LogOut(uid)
|
||||
})
|
||||
|
||||
log.Println("starting ssh server, waiting for debug sessions")
|
||||
|
@ -1,98 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func example() {
|
||||
// Create pipes for stdin and stdout
|
||||
var stdInRead, stdInWrite, stdOutRead, stdOutWrite windows.Handle
|
||||
sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), InheritHandle: 1}
|
||||
|
||||
err := windows.CreatePipe(&stdInRead, &stdInWrite, sa, 0)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating stdin pipe:", err)
|
||||
return
|
||||
}
|
||||
defer windows.CloseHandle(stdInRead)
|
||||
defer windows.CloseHandle(stdInWrite)
|
||||
|
||||
err = windows.CreatePipe(&stdOutRead, &stdOutWrite, sa, 0)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating stdout pipe:", err)
|
||||
return
|
||||
}
|
||||
defer windows.CloseHandle(stdOutRead)
|
||||
defer windows.CloseHandle(stdOutWrite)
|
||||
|
||||
// Set the pipe to non-blocking mode
|
||||
mode := uint32(windows.PIPE_NOWAIT)
|
||||
err = windows.SetNamedPipeHandleState(stdInWrite, &mode, nil, nil)
|
||||
if err != nil {
|
||||
fmt.Println("Error setting stdin pipe to non-blocking:", err)
|
||||
return
|
||||
}
|
||||
err = windows.SetNamedPipeHandleState(stdOutRead, &mode, nil, nil)
|
||||
if err != nil {
|
||||
fmt.Println("Error setting stdout pipe to non-blocking:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare process startup info
|
||||
si := &windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Flags: windows.STARTF_USESTDHANDLES,
|
||||
StdInput: stdInRead,
|
||||
StdOutput: stdOutWrite,
|
||||
StdErr: stdOutWrite,
|
||||
}
|
||||
pi := &windows.ProcessInformation{}
|
||||
|
||||
// Create the process
|
||||
cmd := "cmd.exe"
|
||||
err = windows.CreateProcess(
|
||||
nil,
|
||||
syscall.StringToUTF16Ptr(cmd),
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
si,
|
||||
pi,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating process:", err)
|
||||
return
|
||||
}
|
||||
defer windows.CloseHandle(pi.Process)
|
||||
defer windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Write to the process
|
||||
message := "echo Hello, World!\r\n"
|
||||
var written uint32
|
||||
err = windows.WriteFile(stdInWrite, []byte(message), &written, nil)
|
||||
if err != nil {
|
||||
fmt.Println("Error writing to process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Read from the process
|
||||
buffer := make([]byte, 1024)
|
||||
var read uint32
|
||||
err = windows.ReadFile(stdOutRead, buffer, &read, nil)
|
||||
if err != nil && err != windows.ERROR_NO_DATA {
|
||||
fmt.Println("Error reading from process:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Output: %s", buffer[:read])
|
||||
|
||||
// Wait for the process to finish
|
||||
windows.WaitForSingleObject(pi.Process, windows.INFINITE)
|
||||
}
|
18
pkg/terminal/process.go
Normal file
18
pkg/terminal/process.go
Normal file
@ -0,0 +1,18 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Process interface {
|
||||
Pipe() io.ReadWriter
|
||||
Wait() error
|
||||
}
|
||||
|
||||
// A function definition used to start processes
|
||||
type Spawner func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error)
|
||||
|
||||
func (s Spawner) Start(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
|
||||
return s(sshSession, env, name, arg...)
|
||||
}
|
49
pkg/terminal/pty_linux.go
Normal file
49
pkg/terminal/pty_linux.go
Normal file
@ -0,0 +1,49 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
|
||||
cmd := exec.Command(name, arg...)
|
||||
ptyReq, winCh, isPty := sshSession.Pty()
|
||||
if !isPty {
|
||||
return nil, fmt.Errorf("ssh session is not a pty")
|
||||
}
|
||||
cmd.Env = append(env,
|
||||
fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinsize(f, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
return ptyProcess{cmd: cmd, f: f}, nil
|
||||
})
|
||||
|
||||
func setWinsize(f *os.File, w, h int) {
|
||||
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
|
||||
}
|
||||
|
||||
type ptyProcess struct {
|
||||
cmd *exec.Cmd
|
||||
f *os.File
|
||||
}
|
||||
|
||||
func (p ptyProcess) Pipe() io.ReadWriter {
|
||||
return p.f
|
||||
}
|
||||
func (p ptyProcess) Wait() error {
|
||||
return p.cmd.Wait()
|
||||
}
|
Loading…
Reference in New Issue
Block a user