refactoring to support both windows and linux with totally different Pty code.
This commit is contained in:
		
							parent
							
								
									1ebee30c8c
								
							
						
					
					
						commit
						3bd0e3f3e1
					
				| @ -4,9 +4,9 @@ import ( | ||||
| 	"bufio" | ||||
| 	"converge/pkg/agent" | ||||
| 	"converge/pkg/iowrappers" | ||||
| 	"converge/pkg/terminal" | ||||
| 	"converge/pkg/websocketutil" | ||||
| 	"fmt" | ||||
| 	"github.com/creack/pty" | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/hashicorp/yamux" | ||||
| @ -54,36 +54,20 @@ func passwordAuth(ctx ssh.Context, password string) bool { | ||||
| 
 | ||||
| func sshServer(hostKeyFile string, shellCommand string) *ssh.Server { | ||||
| 	ssh.Handle(func(s ssh.Session) { | ||||
| 
 | ||||
| 		cmd := exec.Command(shellCommand) | ||||
| 		ptyReq, winCh, isPty := s.Pty() | ||||
| 		if isPty { | ||||
| 			workingDirectory, _ := os.Getwd() | ||||
| 			cmd.Env = append(os.Environ(), | ||||
| 				fmt.Sprintf("TERM=%s", ptyReq.Term), | ||||
| 				fmt.Sprintf("agentdir=%s", workingDirectory)) | ||||
| 			f, err := pty.Start(cmd) | ||||
| 			if err != nil { | ||||
| 				panic(err) | ||||
| 			} | ||||
| 			go func() { | ||||
| 				for win := range winCh { | ||||
| 					setWinsize(f, win.Width, win.Height) | ||||
| 				} | ||||
| 			}() | ||||
| 			uid := int(time.Now().UnixMilli()) | ||||
| 			agent.Login(uid, s) | ||||
| 
 | ||||
| 			go func() { | ||||
| 				io.Copy(f, s) // stdin
 | ||||
| 			}() | ||||
| 			io.Copy(s, f) // stdout
 | ||||
| 			cmd.Wait() | ||||
| 			agent.LogOut(uid) | ||||
| 		} else { | ||||
| 			io.WriteString(s, "No PTY requested.\n") | ||||
| 			s.Exit(1) | ||||
| 		workingDirectory, _ := os.Getwd() | ||||
| 		env := append(os.Environ(), fmt.Sprintf("agentdir=%s", workingDirectory)) | ||||
| 		process, err := terminal.PtySpawner.Start(s, env, shellCommand) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		uid := int(time.Now().UnixMilli()) | ||||
| 		agent.Login(uid, s) | ||||
| 		go func() { | ||||
| 			io.Copy(process.Pipe(), s) // stdin
 | ||||
| 		}() | ||||
| 		io.Copy(s, process.Pipe()) // stdout
 | ||||
| 		process.Wait() | ||||
| 		agent.LogOut(uid) | ||||
| 	}) | ||||
| 
 | ||||
| 	log.Println("starting ssh server, waiting for debug sessions") | ||||
|  | ||||
| @ -1,98 +0,0 @@ | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"syscall" | ||||
| 	"unsafe" | ||||
| 
 | ||||
| 	"golang.org/x/sys/windows" | ||||
| ) | ||||
| 
 | ||||
| func example() { | ||||
| 	// Create pipes for stdin and stdout
 | ||||
| 	var stdInRead, stdInWrite, stdOutRead, stdOutWrite windows.Handle | ||||
| 	sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), InheritHandle: 1} | ||||
| 
 | ||||
| 	err := windows.CreatePipe(&stdInRead, &stdInWrite, sa, 0) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error creating stdin pipe:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	defer windows.CloseHandle(stdInRead) | ||||
| 	defer windows.CloseHandle(stdInWrite) | ||||
| 
 | ||||
| 	err = windows.CreatePipe(&stdOutRead, &stdOutWrite, sa, 0) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error creating stdout pipe:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	defer windows.CloseHandle(stdOutRead) | ||||
| 	defer windows.CloseHandle(stdOutWrite) | ||||
| 
 | ||||
| 	// Set the pipe to non-blocking mode
 | ||||
| 	mode := uint32(windows.PIPE_NOWAIT) | ||||
| 	err = windows.SetNamedPipeHandleState(stdInWrite, &mode, nil, nil) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error setting stdin pipe to non-blocking:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	err = windows.SetNamedPipeHandleState(stdOutRead, &mode, nil, nil) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error setting stdout pipe to non-blocking:", err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Prepare process startup info
 | ||||
| 	si := &windows.StartupInfo{ | ||||
| 		Cb:        uint32(unsafe.Sizeof(windows.StartupInfo{})), | ||||
| 		Flags:     windows.STARTF_USESTDHANDLES, | ||||
| 		StdInput:  stdInRead, | ||||
| 		StdOutput: stdOutWrite, | ||||
| 		StdErr:    stdOutWrite, | ||||
| 	} | ||||
| 	pi := &windows.ProcessInformation{} | ||||
| 
 | ||||
| 	// Create the process
 | ||||
| 	cmd := "cmd.exe" | ||||
| 	err = windows.CreateProcess( | ||||
| 		nil, | ||||
| 		syscall.StringToUTF16Ptr(cmd), | ||||
| 		nil, | ||||
| 		nil, | ||||
| 		true, | ||||
| 		0, | ||||
| 		nil, | ||||
| 		nil, | ||||
| 		si, | ||||
| 		pi, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error creating process:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	defer windows.CloseHandle(pi.Process) | ||||
| 	defer windows.CloseHandle(pi.Thread) | ||||
| 
 | ||||
| 	// Write to the process
 | ||||
| 	message := "echo Hello, World!\r\n" | ||||
| 	var written uint32 | ||||
| 	err = windows.WriteFile(stdInWrite, []byte(message), &written, nil) | ||||
| 	if err != nil { | ||||
| 		fmt.Println("Error writing to process:", err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Read from the process
 | ||||
| 	buffer := make([]byte, 1024) | ||||
| 	var read uint32 | ||||
| 	err = windows.ReadFile(stdOutRead, buffer, &read, nil) | ||||
| 	if err != nil && err != windows.ERROR_NO_DATA { | ||||
| 		fmt.Println("Error reading from process:", err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Printf("Output: %s", buffer[:read]) | ||||
| 
 | ||||
| 	// Wait for the process to finish
 | ||||
| 	windows.WaitForSingleObject(pi.Process, windows.INFINITE) | ||||
| } | ||||
							
								
								
									
										18
									
								
								pkg/terminal/process.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								pkg/terminal/process.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| package terminal | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	"io" | ||||
| ) | ||||
| 
 | ||||
| type Process interface { | ||||
| 	Pipe() io.ReadWriter | ||||
| 	Wait() error | ||||
| } | ||||
| 
 | ||||
| // A function definition used to start processes
 | ||||
| type Spawner func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) | ||||
| 
 | ||||
| func (s Spawner) Start(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) { | ||||
| 	return s(sshSession, env, name, arg...) | ||||
| } | ||||
							
								
								
									
										49
									
								
								pkg/terminal/pty_linux.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								pkg/terminal/pty_linux.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | ||||
| package terminal | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/creack/pty" | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	"io" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"syscall" | ||||
| 	"unsafe" | ||||
| ) | ||||
| 
 | ||||
| var PtySpawner = Spawner(func(sshSession ssh.Session, env []string, name string, arg ...string) (Process, error) { | ||||
| 	cmd := exec.Command(name, arg...) | ||||
| 	ptyReq, winCh, isPty := sshSession.Pty() | ||||
| 	if !isPty { | ||||
| 		return nil, fmt.Errorf("ssh session is not a pty") | ||||
| 	} | ||||
| 	cmd.Env = append(env, | ||||
| 		fmt.Sprintf("TERM=%s", ptyReq.Term)) | ||||
| 	f, err := pty.Start(cmd) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for win := range winCh { | ||||
| 			setWinsize(f, win.Width, win.Height) | ||||
| 		} | ||||
| 	}() | ||||
| 	return ptyProcess{cmd: cmd, f: f}, nil | ||||
| }) | ||||
| 
 | ||||
| func setWinsize(f *os.File, w, h int) { | ||||
| 	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), | ||||
| 		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) | ||||
| } | ||||
| 
 | ||||
| type ptyProcess struct { | ||||
| 	cmd *exec.Cmd | ||||
| 	f   *os.File | ||||
| } | ||||
| 
 | ||||
| func (p ptyProcess) Pipe() io.ReadWriter { | ||||
| 	return p.f | ||||
| } | ||||
| func (p ptyProcess) Wait() error { | ||||
| 	return p.cmd.Wait() | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user