initialization of username, password on client (from server) and initialization of agentinfo on server is now done as soon as the agent registered and not through a side channel.
Making use of some simple utilities for GOB to make it easy to send objects over the line.
This commit is contained in:
		
							parent
							
								
									621bbd8ca6
								
							
						
					
					
						commit
						5a492f3855
					
				| @ -14,7 +14,6 @@ import ( | ||||
| 	"github.com/pkg/sftp" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @ -33,7 +32,7 @@ import ( | ||||
| var hostPrivateKey []byte | ||||
| 
 | ||||
| func SftpHandler(sess ssh.Session) { | ||||
| 	uid := int(time.Now().UnixMilli()) | ||||
| 	uid := int(time.Now().UnixMicro()) | ||||
| 	agent.Login(uid, sess) | ||||
| 	defer agent.LogOut(uid) | ||||
| 
 | ||||
| @ -249,8 +248,9 @@ func main() { | ||||
| 	wsConn := websocketutil.NewWebSocketConn(conn) | ||||
| 	defer wsConn.Close() | ||||
| 
 | ||||
| 	err = comms.CheckProtocolVersion(comms.Agent, wsConn) | ||||
| 	serverInfo, err := comms.AgentInitialization(wsConn, comms.NewAgentInfo()) | ||||
| 	if err != nil { | ||||
| 		log.Printf("ERROR: %+v", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| 
 | ||||
| @ -261,7 +261,10 @@ func main() { | ||||
| 
 | ||||
| 	// Authentiocation
 | ||||
| 
 | ||||
| 	sshUserCredentials, passwordHandler, authorizedKeys := setupAuthentication(commChannel, authorizedKeysFile) | ||||
| 	passwordHandler, authorizedKeys := setupAuthentication( | ||||
| 		commChannel, | ||||
| 		serverInfo.UserPassword, | ||||
| 		authorizedKeysFile) | ||||
| 
 | ||||
| 	// Choose shell
 | ||||
| 
 | ||||
| @ -279,9 +282,9 @@ func main() { | ||||
| 	log.Println() | ||||
| 	clientUrl := strings.ReplaceAll(wsURL, "/agent/", "/client/") | ||||
| 	sshCommand := fmt.Sprintf("ssh -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\"  %s@localhost", | ||||
| 		clientUrl, sshUserCredentials.Username) | ||||
| 		clientUrl, serverInfo.UserPassword.Username) | ||||
| 	sftpCommand := fmt.Sprintf("sftp -oServerAliveInterval=10 -oProxyCommand=\"wsproxy %s\" %s@localhost", | ||||
| 		clientUrl, sshUserCredentials.Username) | ||||
| 		clientUrl, serverInfo.UserPassword.Username) | ||||
| 	log.Println("  # For SSH") | ||||
| 	log.Println("  " + sshCommand) | ||||
| 	log.Println() | ||||
| @ -302,26 +305,20 @@ func main() { | ||||
| 	service.Run(listener) | ||||
| } | ||||
| 
 | ||||
| func setupAuthentication(commChannel comms.CommChannel, authorizedKeysFile string) (comms.UserPassword, func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { | ||||
| 	// Random user name and password so that effectively no one can login
 | ||||
| 	// until the user and password have been received from the server.
 | ||||
| 	sshUserCredentials := comms.UserPassword{ | ||||
| 		Username: strconv.Itoa(rand.Int()), | ||||
| 		Password: strconv.Itoa(rand.Int()), | ||||
| 	} | ||||
| func setupAuthentication(commChannel comms.CommChannel, | ||||
| 	userPassword comms.UserPassword, | ||||
| 	authorizedKeysFile string) (func(ctx ssh.Context, password string) bool, AuthorizedPublicKeys) { | ||||
| 
 | ||||
| 	passwordHandler := func(ctx ssh.Context, password string) bool { | ||||
| 		// Replace with your own logic to validate username and password
 | ||||
| 		return ctx.User() == sshUserCredentials.Username && password == sshUserCredentials.Password | ||||
| 		return ctx.User() == userPassword.Username && password == userPassword.Password | ||||
| 	} | ||||
| 	go comms.ListenForServerEvents(commChannel, func(user comms.UserPassword) { | ||||
| 		log.Println("Username and password configuration received from server") | ||||
| 		sshUserCredentials = user | ||||
| 	}) | ||||
| 	go comms.ListenForServerEvents(commChannel) | ||||
| 	authorizedKeys := ParseOpenSSHAuthorizedKeysFile(authorizedKeysFile) | ||||
| 	if len(authorizedKeys.keys) > 0 { | ||||
| 		log.Printf("A total of %d authorized ssh keys were found", len(authorizedKeys.keys)) | ||||
| 	} | ||||
| 	return sshUserCredentials, passwordHandler, authorizedKeys | ||||
| 	return passwordHandler, authorizedKeys | ||||
| } | ||||
| 
 | ||||
| func chooseShell() string { | ||||
|  | ||||
| @ -101,8 +101,10 @@ func ConfigureAgent(commChannel comms.CommChannel, | ||||
| 	log.Printf("Agent expires at %s", | ||||
| 		state.expiryTime(holdFilename).Format(time.DateTime)) | ||||
| 
 | ||||
| 	state.commChannel.SideChannel.Send(comms.NewAgentInfo()) | ||||
| 	state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) | ||||
| 	comms.Send(state.commChannel.SideChannel, | ||||
| 		comms.ConvergeMessage{ | ||||
| 			Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)), | ||||
| 		}) | ||||
| 
 | ||||
| 	go func() { | ||||
| 		for { | ||||
| @ -182,7 +184,10 @@ func login(sessionId int, sshSession ssh.Session) { | ||||
| 	if sessionType == "" { | ||||
| 		sessionType = "ssh" | ||||
| 	} | ||||
| 	state.commChannel.SideChannel.Send(comms.NewSessionInfo(sessionType)) | ||||
| 	comms.Send(state.commChannel.SideChannel, | ||||
| 		comms.ConvergeMessage{ | ||||
| 			Value: comms.NewSessionInfo(sessionType), | ||||
| 		}) | ||||
| 
 | ||||
| 	holdFileStats, ok := fileExistsWithStats(holdFilename) | ||||
| 	if ok { | ||||
| @ -297,7 +302,10 @@ func holdFileChange() { | ||||
| 		message += holdFileMessage() | ||||
| 		messageUsers(message) | ||||
| 		state.lastExpiryTimmeReported = newExpiryTIme | ||||
| 		state.commChannel.SideChannel.Send(comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename))) | ||||
| 		comms.Send(state.commChannel.SideChannel, | ||||
| 			comms.ConvergeMessage{ | ||||
| 				Value: comms.NewExpiryTimeUpdate(state.expiryTime(holdFilename)), | ||||
| 			}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -8,6 +8,8 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const MESSAGE_TIMEOUT = 10 * time.Second | ||||
| 
 | ||||
| type CommChannel struct { | ||||
| 	// a separet connection outside of the ssh session
 | ||||
| 	SideChannel GOBChannel | ||||
| @ -70,7 +72,10 @@ func NewCommChannel(role Role, wsConn io.ReadWriteCloser) (CommChannel, error) { | ||||
| 		go func() { | ||||
| 			for { | ||||
| 				time.Sleep(10 * time.Second) | ||||
| 				err := commChannel.SideChannel.Send(HeartBeat{}) | ||||
| 				err := Send(commChannel.SideChannel, | ||||
| 					ConvergeMessage{ | ||||
| 						Value: HeartBeat{}, | ||||
| 					}) | ||||
| 				if err != nil { | ||||
| 					log.Println("Sending heartbeat to server failed") | ||||
| 				} | ||||
| @ -113,13 +118,12 @@ func ListenForAgentEvents(channel GOBChannel, | ||||
| 			// intended to keep the connection up
 | ||||
| 
 | ||||
| 		default: | ||||
| 			fmt.Printf("  Unknown type: %T\n", v) | ||||
| 			fmt.Printf("  Unknown type: %v %T\n", v, v) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func ListenForServerEvents(channel CommChannel, | ||||
| 	setUsernamePassword func(user UserPassword)) { | ||||
| func ListenForServerEvents(channel CommChannel) { | ||||
| 	for { | ||||
| 		var result ConvergeMessage | ||||
| 		err := channel.SideChannel.Decoder.Decode(&result) | ||||
| @ -129,10 +133,9 @@ func ListenForServerEvents(channel CommChannel, | ||||
| 			log.Printf("Exiting agent listener %v", err) | ||||
| 			return | ||||
| 		} | ||||
| 		switch v := result.Value.(type) { | ||||
| 
 | ||||
| 		case UserPassword: | ||||
| 			setUsernamePassword(v) | ||||
| 		// no supported server events at this time.
 | ||||
| 		switch v := result.Value.(type) { | ||||
| 
 | ||||
| 		default: | ||||
| 			fmt.Printf("  Unknown type: %T\n", v) | ||||
| @ -140,25 +143,102 @@ func ListenForServerEvents(channel CommChannel, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CheckProtocolVersion(role Role, conn io.ReadWriter) error { | ||||
| func AgentInitialization(conn io.ReadWriter, agentInto AgentInfo) (ServerInfo, error) { | ||||
| 	channel := NewGOBChannel(conn) | ||||
| 	err := CheckProtocolVersion(Agent, channel) | ||||
| 
 | ||||
| 	err = SendWithTimeout(channel, agentInto) | ||||
| 	if err != nil { | ||||
| 		return ServerInfo{}, nil | ||||
| 	} | ||||
| 	serverInfo, err := ReceiveWithTimeout[ServerInfo](channel) | ||||
| 	if err != nil { | ||||
| 		return ServerInfo{}, nil | ||||
| 	} | ||||
| 	// TODO remove logging
 | ||||
| 	log.Println("Server info received: ", serverInfo) | ||||
| 
 | ||||
| 	return serverInfo, err | ||||
| } | ||||
| 
 | ||||
| func ServerInitialization(conn io.ReadWriter, serverInfo ServerInfo) (AgentInfo, error) { | ||||
| 	channel := NewGOBChannel(conn) | ||||
| 	err := CheckProtocolVersion(ConvergeServer, channel) | ||||
| 
 | ||||
| 	agentInfo, err := ReceiveWithTimeout[AgentInfo](channel) | ||||
| 	if err != nil { | ||||
| 		return AgentInfo{}, err | ||||
| 	} | ||||
| 	log.Println("Agent info received: ", agentInfo) | ||||
| 	err = SendWithTimeout(channel, serverInfo) | ||||
| 	if err != nil { | ||||
| 		return AgentInfo{}, nil | ||||
| 	} | ||||
| 	return agentInfo, err | ||||
| } | ||||
| 
 | ||||
| // Events sent over the websocket connection that is established between
 | ||||
| // agent and converge server. This is done as soon as the agent starts.
 | ||||
| 
 | ||||
| // First commmunication between agent and Converge Server
 | ||||
| // Both exchange their protocol version and if it is incorrect, the session
 | ||||
| // is terminated.
 | ||||
| 
 | ||||
| func CheckProtocolVersion(role Role, channel GOBChannel) error { | ||||
| 	log.Println("ROLE ", role) | ||||
| 	switch role { | ||||
| 	case Agent: | ||||
| 		err := SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		version, err := ReceiveWithTimeout[ProtocolVersion](channel) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if version.Version != PROTOCOL_VERSION { | ||||
| 			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", | ||||
| 				PROTOCOL_VERSION, version.Version) | ||||
| 		} | ||||
| 		return nil | ||||
| 	case ConvergeServer: | ||||
| 		version, err := ReceiveWithTimeout[ProtocolVersion](channel) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		err = SendWithTimeout(channel, ProtocolVersion{Version: PROTOCOL_VERSION}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if version.Version != PROTOCOL_VERSION { | ||||
| 			return fmt.Errorf("Protocol version mismatch: agent %d, converge server %d", | ||||
| 				PROTOCOL_VERSION, version.Version) | ||||
| 		} | ||||
| 		return nil | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("unexpected rolg %v", role)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CheckProtocolVersionOld(role Role, conn io.ReadWriter) error { | ||||
| 	channel := NewGOBChannel(conn) | ||||
| 
 | ||||
| 	sends := make(chan any) | ||||
| 	receives := make(chan any) | ||||
| 	sends := make(chan bool) | ||||
| 	receives := make(chan ProtocolVersion) | ||||
| 	errors := make(chan error) | ||||
| 
 | ||||
| 	channel.SendAsync(ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) | ||||
| 	channel.ReceiveAsync(receives, errors) | ||||
| 	SendAsync(channel, ProtocolVersion{Version: PROTOCOL_VERSION}, sends, errors) | ||||
| 	ReceiveAsync(channel, receives, errors) | ||||
| 
 | ||||
| 	select { | ||||
| 	case <-time.After(10 * time.Second): | ||||
| 	case <-time.After(MESSAGE_TIMEOUT): | ||||
| 		log.Println("PROTOCOLVERSION: timeout") | ||||
| 		return fmt.Errorf("Timeout waiting for protocol version") | ||||
| 	case err := <-errors: | ||||
| 		log.Printf("PROTOCOLVERSION: %v", err) | ||||
| 		return err | ||||
| 	case protocolVersion := <-receives: | ||||
| 		otherVersion := protocolVersion.(ProtocolVersion).Version | ||||
| 		otherVersion := protocolVersion.Version | ||||
| 		if PROTOCOL_VERSION != otherVersion { | ||||
| 			switch role { | ||||
| 			case Agent: | ||||
| @ -170,7 +250,12 @@ func CheckProtocolVersion(role Role, conn io.ReadWriter) error { | ||||
| 			} | ||||
| 			return fmt.Errorf("Protocol version mismatch") | ||||
| 		} | ||||
| 		log.Printf("PROTOCOLVERSION: %v", protocolVersion.(ProtocolVersion).Version) | ||||
| 		log.Printf("PROTOCOLVERSION: %v", protocolVersion.Version) | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Session info metadata exchange. These are sent over the SSH connection. The agent embedded
 | ||||
| // ssh serverlisterns for connections, but we provide a custom listener (AgentListener) that
 | ||||
| // decorates the yamux Session (which is a listener) and uses this connection to exchange some
 | ||||
| // metadata before the connection is handed back to SSH.
 | ||||
|  | ||||
| @ -23,6 +23,10 @@ type AgentInfo struct { | ||||
| 	OS       string | ||||
| } | ||||
| 
 | ||||
| type ClientInfo struct { | ||||
| 	ClientId string | ||||
| } | ||||
| 
 | ||||
| type SessionInfo struct { | ||||
| 	// "ssh", "sftp"
 | ||||
| 	SessionType string | ||||
| @ -47,8 +51,7 @@ type UserPassword struct { | ||||
| 	Password string | ||||
| } | ||||
| 
 | ||||
| type ConnectionInfo struct { | ||||
| 	ConnectionId int | ||||
| type ServerInfo struct { | ||||
| 	UserPassword UserPassword | ||||
| } | ||||
| 
 | ||||
| @ -88,7 +91,6 @@ func RegisterEventsWithGob() { | ||||
| 	// ConvergeServer to Agent
 | ||||
| 	gob.Register(ProtocolVersion{}) | ||||
| 	gob.Register(UserPassword{}) | ||||
| 	gob.Register(ConnectionInfo{}) | ||||
| 
 | ||||
| 	// Wrapper event.
 | ||||
| 	gob.Register(ConvergeMessage{}) | ||||
|  | ||||
| @ -2,8 +2,10 @@ package comms | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/gob" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type GOBChannel struct { | ||||
| @ -22,12 +24,26 @@ func NewGOBChannel(conn io.ReadWriter) GOBChannel { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Send(channel GOBChannel, object any) error { | ||||
| 	err := channel.Encoder.Encode(object) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Encoding error %v", err) | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func Receive[T any](channel GOBChannel) (T, error) { | ||||
| 	target := *new(T) | ||||
| 	err := channel.Decoder.Decode(&target) | ||||
| 	return target, err | ||||
| } | ||||
| 
 | ||||
| // Asynchronous send and receive on a single connection is guaranteed to preserver ordering of
 | ||||
| // messages. We use asynchronous to void blocking indefinitely or depending on network timeouts.
 | ||||
| 
 | ||||
| func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- error) { | ||||
| func SendAsync[T any](channel GOBChannel, obj T, done chan<- bool, errors chan<- error) { | ||||
| 	go func() { | ||||
| 		err := channel.Send(obj) | ||||
| 		err := Send(channel, obj) | ||||
| 		if err != nil { | ||||
| 			errors <- err | ||||
| 		} else { | ||||
| @ -36,9 +52,9 @@ func (channel GOBChannel) SendAsync(obj any, done chan<- any, errors chan<- erro | ||||
| 	}() | ||||
| } | ||||
| 
 | ||||
| func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { | ||||
| func ReceiveAsync[T any](channel GOBChannel, result chan T, errors chan<- error) { | ||||
| 	go func() { | ||||
| 		value, err := channel.Receive() | ||||
| 		value, err := Receive[T](channel) | ||||
| 		if err != nil { | ||||
| 			errors <- err | ||||
| 		} else { | ||||
| @ -47,16 +63,32 @@ func (channel GOBChannel) ReceiveAsync(result chan<- any, errors chan<- error) { | ||||
| 	}() | ||||
| } | ||||
| 
 | ||||
| func (channel GOBChannel) Send(object any) error { | ||||
| 	err := channel.Encoder.Encode(ConvergeMessage{Value: object}) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Encoding error %v", err) | ||||
| func SendWithTimeout[T any](channel GOBChannel, obj T) error { | ||||
| 	done := make(chan bool) | ||||
| 	errors := make(chan error) | ||||
| 
 | ||||
| 	SendAsync(channel, obj, done, errors) | ||||
| 	select { | ||||
| 	case <-time.After(MESSAGE_TIMEOUT): | ||||
| 		return fmt.Errorf("Timeout in SwndWithTimout") | ||||
| 	case err := <-errors: | ||||
| 		return err | ||||
| 	case <-done: | ||||
| 		return nil | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (channel GOBChannel) Receive() (any, error) { | ||||
| 	var target ConvergeMessage | ||||
| 	err := channel.Decoder.Decode(&target) | ||||
| 	return target.Value, err | ||||
| func ReceiveWithTimeout[T any](channel GOBChannel) (T, error) { | ||||
| 	result := make(chan T) | ||||
| 	errors := make(chan error) | ||||
| 
 | ||||
| 	ReceiveAsync(channel, result, errors) | ||||
| 	select { | ||||
| 	case <-time.After(MESSAGE_TIMEOUT): | ||||
| 		return *new(T), fmt.Errorf("Timeout in ReceiveWithTimout") | ||||
| 	case err := <-errors: | ||||
| 		return *new(T), err | ||||
| 	case value := <-result: | ||||
| 		return value, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -29,11 +29,12 @@ type Client struct { | ||||
| 	sessionType string | ||||
| } | ||||
| 
 | ||||
| func NewAgent(commChannel comms.CommChannel, publicId string) *Agent { | ||||
| func NewAgent(commChannel comms.CommChannel, publicId string, agentInfo comms.AgentInfo) *Agent { | ||||
| 	return &Agent{ | ||||
| 		commChannel: commChannel, | ||||
| 		publicId:    publicId, | ||||
| 		startTime:   time.Now(), | ||||
| 		agentInfo:   agentInfo, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -88,7 +89,7 @@ func (admin *Admin) logStatus() { | ||||
| 	log.Printf("\n") | ||||
| } | ||||
| 
 | ||||
| func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, error) { | ||||
| func (admin *Admin) addAgent(publicId string, agentInfo comms.AgentInfo, conn io.ReadWriteCloser) (*Agent, error) { | ||||
| 	admin.mutex.Lock() | ||||
| 	defer admin.mutex.Unlock() | ||||
| 
 | ||||
| @ -102,7 +103,7 @@ func (admin *Admin) addAgent(publicId string, conn io.ReadWriteCloser) (*Agent, | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	agent = NewAgent(commChannel, publicId) | ||||
| 	agent = NewAgent(commChannel, publicId, agentInfo) | ||||
| 	admin.agents[publicId] = agent | ||||
| 	admin.logStatus() | ||||
| 	return agent, nil | ||||
| @ -172,13 +173,16 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, | ||||
| 	userPassword comms.UserPassword) error { | ||||
| 	defer conn.Close() | ||||
| 
 | ||||
| 	err := comms.CheckProtocolVersion(comms.ConvergeServer, conn) | ||||
| 	serverInfo := comms.ServerInfo{ | ||||
| 		UserPassword: userPassword, | ||||
| 	} | ||||
| 
 | ||||
| 	agentInfo, err := comms.ServerInitialization(conn, serverInfo) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: remove agent return value
 | ||||
| 	agent, err := admin.addAgent(publicId, conn) | ||||
| 	agent, err := admin.addAgent(publicId, agentInfo, conn) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -186,9 +190,6 @@ func (admin *Admin) Register(publicId string, conn io.ReadWriteCloser, | ||||
| 		admin.RemoveAgent(publicId) | ||||
| 	}() | ||||
| 
 | ||||
| 	log.Println("Sending username and password to agent") | ||||
| 	agent.commChannel.SideChannel.Send(userPassword) | ||||
| 
 | ||||
| 	go func() { | ||||
| 		comms.ListenForAgentEvents(agent.commChannel.SideChannel, | ||||
| 			func(info comms.AgentInfo) { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user