refactoring to support both windows and linux with totally different Pty code.
This commit is contained in:
parent
1ebee30c8c
commit
3bd0e3f3e1
@ -4,9 +4,9 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"converge/pkg/agent"
|
"converge/pkg/agent"
|
||||||
"converge/pkg/iowrappers"
|
"converge/pkg/iowrappers"
|
||||||
|
"converge/pkg/terminal"
|
||||||
"converge/pkg/websocketutil"
|
"converge/pkg/websocketutil"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/creack/pty"
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/hashicorp/yamux"
|
"github.com/hashicorp/yamux"
|
||||||
@ -54,36 +54,20 @@ func passwordAuth(ctx ssh.Context, password string) bool {
|
|||||||
|
|
||||||
func sshServer(hostKeyFile string, shellCommand string) *ssh.Server {
|
func sshServer(hostKeyFile string, shellCommand string) *ssh.Server {
|
||||||
ssh.Handle(func(s ssh.Session) {
|
ssh.Handle(func(s ssh.Session) {
|
||||||
|
|
||||||
cmd := exec.Command(shellCommand)
|
|
||||||
ptyReq, winCh, isPty := s.Pty()
|
|
||||||
if isPty {
|
|
||||||
workingDirectory, _ := os.Getwd()
|
workingDirectory, _ := os.Getwd()
|
||||||
cmd.Env = append(os.Environ(),
|
env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory))
|
||||||
fmt.Sprintf("TERM=%s", ptyReq.Term),
|
process, err := terminal.PtySpawner.Start(s, env, shellCommand)
|
||||||
fmt.Sprintf("agentdir=%s", workingDirectory))
|
|
||||||
f, err := pty.Start(cmd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
for win := range winCh {
|
|
||||||
setWinsize(f, win.Width, win.Height)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
uid := int(time.Now().UnixMilli())
|
uid := int(time.Now().UnixMilli())
|
||||||
agent.Login(uid, s)
|
agent.Login(uid, s)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(f, s) // stdin
|
io.Copy(process.Pipe(), s) // stdin
|
||||||
}()
|
}()
|
||||||
io.Copy(s, f) // stdout
|
io.Copy(s, process.Pipe()) // stdout
|
||||||
cmd.Wait()
|
process.Wait()
|
||||||
agent.LogOut(uid)
|
agent.LogOut(uid)
|
||||||
} else {
|
|
||||||
io.WriteString(s, "No PTY requested.\n")
|
|
||||||
s.Exit(1)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Println("starting ssh server, waiting for debug sessions")
|
log.Println("starting ssh server, waiting for debug sessions")
|
||||||
|
@ -1,98 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func example() {
|
|
||||||
// Create pipes for stdin and stdout
|
|
||||||
var stdInRead, stdInWrite, stdOutRead, stdOutWrite windows.Handle
|
|
||||||
sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), InheritHandle: 1}
|
|
||||||
|
|
||||||
err := windows.CreatePipe(&stdInRead, &stdInWrite, sa, 0)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error creating stdin pipe:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(stdInRead)
|
|
||||||
defer windows.CloseHandle(stdInWrite)
|
|
||||||
|
|
||||||
err = windows.CreatePipe(&stdOutRead, &stdOutWrite, sa, 0)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error creating stdout pipe:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(stdOutRead)
|
|
||||||
defer windows.CloseHandle(stdOutWrite)
|
|
||||||
|
|
||||||
// Set the pipe to non-blocking mode
|
|
||||||
mode := uint32(windows.PIPE_NOWAIT)
|
|
||||||
err = windows.SetNamedPipeHandleState(stdInWrite, &mode, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error setting stdin pipe to non-blocking:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = windows.SetNamedPipeHandleState(stdOutRead, &mode, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error setting stdout pipe to non-blocking:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare process startup info
|
|
||||||
si := &windows.StartupInfo{
|
|
||||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
|
||||||
Flags: windows.STARTF_USESTDHANDLES,
|
|
||||||
StdInput: stdInRead,
|
|
||||||
StdOutput: stdOutWrite,
|
|
||||||
StdErr: stdOutWrite,
|
|
||||||
}
|
|
||||||
pi := &windows.ProcessInformation{}
|
|
||||||
|
|
||||||
// Create the process
|
|
||||||
cmd := "cmd.exe"
|
|
||||||
err = windows.CreateProcess(
|
|
||||||
nil,
|
|
||||||
syscall.StringToUTF16Ptr(cmd),
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
true,
|
|
||||||
0,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
si,
|
|
||||||
pi,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error creating process:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(pi.Process)
|
|
||||||
defer windows.CloseHandle(pi.Thread)
|
|
||||||
|
|
||||||
// Write to the process
|
|
||||||
message := "echo Hello, World!\r\n"
|
|
||||||
var written uint32
|
|
||||||
err = windows.WriteFile(stdInWrite, []byte(message), &written, nil)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Error writing to process:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read from the process
|
|
||||||
buffer := make([]byte, 1024)
|
|
||||||
var read uint32
|
|
||||||
err = windows.ReadFile(stdOutRead, buffer, &read, nil)
|
|
||||||
if err != nil && err != windows.ERROR_NO_DATA {
|
|
||||||
fmt.Println("Error reading from process:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Output: %s", buffer[:read])
|
|
||||||
|
|
||||||
// Wait for the process to finish
|
|
||||||
windows.WaitForSingleObject(pi.Process, windows.INFINITE)
|
|
||||||
}
|
|
18
pkg/terminal/process.go
Normal file
18
pkg/terminal/process.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package terminal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Process interface {
|
||||||
|
Pipe() io.ReadWriter
|
||||||
|
Wait() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// A function definition used to start processes
|
||||||
|
type Spawner func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error)
|
||||||
|
|
||||||
|
func (s Spawner) Start(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
|
||||||
|
return s(sshSession, env, name, arg...)
|
||||||
|
}
|
49
pkg/terminal/pty_linux.go
Normal file
49
pkg/terminal/pty_linux.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package terminal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/creack/pty"
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) {
|
||||||
|
cmd := exec.Command(name, arg...)
|
||||||
|
ptyReq, winCh, isPty := sshSession.Pty()
|
||||||
|
if !isPty {
|
||||||
|
return nil, fmt.Errorf("ssh session is not a pty")
|
||||||
|
}
|
||||||
|
cmd.Env = append(env,
|
||||||
|
fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||||
|
f, err := pty.Start(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for win := range winCh {
|
||||||
|
setWinsize(f, win.Width, win.Height)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ptyProcess{cmd: cmd, f: f}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
func setWinsize(f *os.File, w, h int) {
|
||||||
|
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||||
|
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
|
||||||
|
}
|
||||||
|
|
||||||
|
type ptyProcess struct {
|
||||||
|
cmd *exec.Cmd
|
||||||
|
f *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p ptyProcess) Pipe() io.ReadWriter {
|
||||||
|
return p.f
|
||||||
|
}
|
||||||
|
func (p ptyProcess) Wait() error {
|
||||||
|
return p.cmd.Wait()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user