From bcc6b2ee5c6bd61bc1157dce1dc8a0676c912c8b Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Wed, 13 May 2026 20:03:58 +0300 Subject: [PATCH] feat: remove unused client ID from config --- cmd/olcrtc/main_test.go | 5 +- internal/app/session/session.go | 6 - internal/app/session/session_test.go | 10 - internal/client/client.go | 140 ++++++++++-- internal/client/client_test.go | 6 +- internal/config/config.go | 4 +- internal/config/config_test.go | 3 +- internal/e2e/tunnel_test.go | 67 +----- internal/handshake/handshake.go | 214 ++++++++++++++++++ internal/handshake/handshake_test.go | 128 +++++++++++ internal/link/direct/direct.go | 2 +- internal/link/direct/direct_test.go | 4 +- internal/link/link.go | 2 +- internal/link/link_test.go | 4 +- internal/server/server.go | 109 +++++++-- internal/server/server_test.go | 32 ++- internal/transport/transport.go | 2 +- internal/transport/transport_test.go | 4 +- internal/transport/vp8channel/transport.go | 2 +- .../vp8channel/transport_unit_test.go | 2 +- mobile/mobile.go | 6 +- mobile/mobile_test.go | 4 +- 22 files changed, 600 insertions(+), 156 deletions(-) create mode 100644 internal/handshake/handshake.go create mode 100644 internal/handshake/handshake_test.go diff --git a/cmd/olcrtc/main_test.go b/cmd/olcrtc/main_test.go index 44f93ef..18f4ddf 100644 --- a/cmd/olcrtc/main_test.go +++ b/cmd/olcrtc/main_test.go @@ -83,7 +83,6 @@ func TestRunWithConfigValidationAndDataDirErrors(t *testing.T) { Link: "direct", Transport: "datachannel", Auth: "jazz", - ClientID: "client", KeyHex: "key", DNSServer: "1.1.1.1:53", } @@ -113,7 +112,7 @@ func TestRunWithArgsSuccessfulSessionReturn(t *testing.T) { called := false runSession = func(ctx context.Context, cfg session.Config) error { called = true - if cfg.Mode != "srv" || cfg.Auth != "jazz" || cfg.ClientID != "client" { + if cfg.Mode != "srv" || cfg.Auth != "jazz" { t.Fatalf("session config = %+v", cfg) } select { @@ -129,8 +128,6 @@ mode: srv link: direct auth: provider: jazz -room: - client_id: client crypto: key: key net: diff --git a/internal/app/session/session.go b/internal/app/session/session.go index dfd21a4..9bbea71 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -101,8 +101,6 @@ var ( ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)") // ErrSOCKSPortRequired indicates that socks port is required for cnc mode. ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)") - // ErrClientIDRequired indicates that client ID is required. - ErrClientIDRequired = errors.New("client ID required (use -client-id )") ) // Config holds runtime session settings. @@ -115,7 +113,6 @@ type Config struct { URL string Token string RoomID string - ClientID string KeyHex string SOCKSHost string SOCKSPort int @@ -242,9 +239,6 @@ func validateCommon(cfg Config) error { if cfg.RoomID == "" && cfg.Auth != authJazz && cfg.Auth != authNone { return ErrRoomIDRequired } - if cfg.ClientID == "" { - return ErrClientIDRequired - } if cfg.KeyHex == "" { return ErrKeyRequired } diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index ab705fe..6ca3f79 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -16,7 +16,6 @@ func TestValidate(t *testing.T) { Transport: "datachannel", Auth: "telemost", RoomID: "room-1", - ClientID: "client-1", KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff", DNSServer: "1.1.1.1:53", //nolint:goconst // test literal, repetition is intentional } @@ -91,15 +90,6 @@ func TestValidate(t *testing.T) { }(), want: ErrRoomIDRequired, }, - { - name: "client id required", - cfg: func() Config { - cfg := base - cfg.ClientID = "" - return cfg - }(), - want: ErrClientIDRequired, - }, { name: "key required", cfg: func() Config { diff --git a/internal/client/client.go b/internal/client/client.go index 06a4a94..0b81275 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -10,10 +10,15 @@ import ( "fmt" "io" "net" + "os" + "path/filepath" + "strings" "sync" "time" + "github.com/google/uuid" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" @@ -44,15 +49,18 @@ var ( // Client handles local SOCKS5 connections and tunnels them to the server. type Client struct { - ln link.Link - cipher *crypto.Cipher - conn *muxconn.Conn - session *smux.Session - sessMu sync.RWMutex - clientID string - dnsServer string - socksUser string - socksPass string + ln link.Link + cipher *crypto.Cipher + conn *muxconn.Conn + session *smux.Session + controlStrm *smux.Stream + sessMu sync.RWMutex + deviceID string + sessionID string + claims map[string]any + dnsServer string + socksUser string + socksPass string } // Config holds runtime configuration for [Run] and [RunWithReady]. @@ -62,7 +70,6 @@ type Config struct { Carrier string RoomURL string KeyHex string - ClientID string LocalAddr string DNSServer string SOCKSUser string @@ -86,6 +93,19 @@ type Config struct { Engine string URL string Token string + + // DeviceID overrides the persistent client-side device identifier. Leave + // empty to derive one from DeviceIDPath (or generate a random one if both + // are empty). + DeviceID string + + // DeviceIDPath is a file in which to persist the auto-generated device ID + // across restarts. Ignored when DeviceID is set explicitly. + DeviceIDPath string + + // Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to + // the server's AuthHook. Free-form key/value bag for plan, user, region, etc. + Claims map[string]any } // Run starts the client with the given configuration. @@ -103,9 +123,15 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error { return fmt.Errorf("setupCipher failed: %w", err) } + deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath) + if err != nil { + return fmt.Errorf("resolve device id: %w", err) + } + c := &Client{ cipher: cipher, - clientID: cfg.ClientID, + deviceID: deviceID, + claims: cfg.Claims, dnsServer: cfg.DNSServer, socksUser: cfg.SOCKSUser, socksPass: cfg.SOCKSPass, @@ -147,7 +173,7 @@ func (c *Client) bringUpLink( Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, - ClientID: c.clientID, + DeviceID: c.deviceID, Name: names.Generate(), OnData: c.onData, DNSServer: cfg.DNSServer, @@ -188,14 +214,80 @@ func (c *Client) bringUpLink( if err != nil { return fmt.Errorf("smux client: %w", err) } + + control, sid, err := openControlStream(sess, c.deviceID, c.claims) + if err != nil { + _ = sess.Close() + _ = c.conn.Close() + return fmt.Errorf("handshake: %w", err) + } + logger.Infof("session %s opened (device=%s)", sid, c.deviceID) + c.sessMu.Lock() c.session = sess + c.controlStrm = control + c.sessionID = sid c.sessMu.Unlock() go ln.WatchConnection(ctx) return nil } +// openControlStream opens stream #1 on sess and performs the handshake. +// The stream stays open for the lifetime of the smux session — the server +// holds it parked, and it would carry future control messages. +func openControlStream( + sess *smux.Session, + deviceID string, + claims map[string]any, +) (*smux.Stream, string, error) { + stream, err := sess.OpenStream() + if err != nil { + return nil, "", fmt.Errorf("open control stream: %w", err) + } + _ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout)) + sid, err := handshake.Client(stream, deviceID, claims) + _ = stream.SetDeadline(time.Time{}) + if err != nil { + _ = stream.Close() + return nil, "", err + } + return stream, sid, nil +} + +// resolveDeviceID returns the device ID to send in CLIENT_HELLO. +// +// Precedence: +// 1. Explicit deviceID arg (Config.DeviceID) — used verbatim. +// 2. Persistent file at path (Config.DeviceIDPath) — read if it exists, +// otherwise generated and written for future runs. +// 3. Random UUID per run when both inputs are empty. +func resolveDeviceID(deviceID, path string) (string, error) { + if deviceID != "" { + return deviceID, nil + } + if path == "" { + return uuid.NewString(), nil + } + data, err := os.ReadFile(path) + if err == nil { + id := strings.TrimSpace(string(data)) + if id != "" { + return id, nil + } + } else if !errors.Is(err, os.ErrNotExist) { + return "", fmt.Errorf("read device id %s: %w", path, err) + } + id := uuid.NewString() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return "", fmt.Errorf("mkdir device id dir: %w", err) + } + if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil { + return "", fmt.Errorf("write device id %s: %w", path, err) + } + return id, nil +} + // smuxConfig returns the tuned smux config used on both ends. func smuxConfig() *smux.Config { cfg := smux.DefaultConfig() @@ -212,6 +304,10 @@ func smuxConfig() *smux.Config { func (c *Client) handleReconnect() { logger.Infof("client link reconnect - tearing down smux session") c.sessMu.Lock() + if c.controlStrm != nil { + _ = c.controlStrm.Close() + c.controlStrm = nil + } if c.session != nil { _ = c.session.Close() c.session = nil @@ -220,6 +316,7 @@ func (c *Client) handleReconnect() { _ = c.conn.Close() c.conn = nil } + c.sessionID = "" c.sessMu.Unlock() c.conn = muxconn.New(c.ln, c.cipher) sess, err := smux.Client(c.conn, smuxConfig()) @@ -227,13 +324,25 @@ func (c *Client) handleReconnect() { logger.Warnf("smux re-init failed: %v", err) return } + control, sid, err := openControlStream(sess, c.deviceID, c.claims) + if err != nil { + logger.Warnf("handshake on reconnect failed: %v", err) + _ = sess.Close() + return + } + logger.Infof("session %s reopened (device=%s)", sid, c.deviceID) c.sessMu.Lock() c.session = sess + c.controlStrm = control + c.sessionID = sid c.sessMu.Unlock() } func (c *Client) shutdown() { c.sessMu.Lock() + if c.controlStrm != nil { + _ = c.controlStrm.Close() + } if c.session != nil { _ = c.session.Close() } @@ -340,10 +449,9 @@ func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, ta func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error { connectReq, err := json.Marshal(map[string]any{ - "cmd": "connect", - "clientId": c.clientID, - "addr": targetAddr, - "port": targetPort, + "cmd": "connect", + "addr": targetAddr, + "port": targetPort, }) if err != nil { return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err) diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 3aa146b..ebe6745 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -417,7 +417,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) { done <- err return } - if req["cmd"] != "connect" || req["clientId"] != "client-1" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional + if req["cmd"] != "connect" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional done <- errUnexpectedConnectRequest return } @@ -431,7 +431,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) { } defer func() { _ = stream.Close() }() - c := &Client{clientID: "client-1"} + c := &Client{deviceID: "client-1"} if err := c.sendConnectRequest(stream, "example.com", 443); err != nil { t.Fatalf("sendConnectRequest() error = %v", err) } @@ -473,7 +473,7 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) { } defer func() { _ = stream.Close() }() - c := &Client{clientID: "client-1"} + c := &Client{deviceID: "client-1"} if err := c.sendConnectRequest(stream, "example.com", 443); !errors.Is(err, ErrRemoteNotReady) { t.Fatalf("sendConnectRequest() error = %v, want %v", err, ErrRemoteNotReady) } diff --git a/internal/config/config.go b/internal/config/config.go index 9b60e71..9fcad0a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,8 +45,7 @@ type Auth struct { // Room identifies the conference room. type Room struct { - ID string `yaml:"id"` - ClientID string `yaml:"client_id"` // deprecated: server identifier (will be removed) + ID string `yaml:"id"` } // Crypto holds the shared secret used to authenticate and encrypt the tunnel. @@ -137,7 +136,6 @@ func Apply(dst session.Config, f File) session.Config { dst.URL = pickString(dst.URL, f.Engine.URL) dst.Token = pickString(dst.Token, f.Engine.Token) dst.RoomID = pickString(dst.RoomID, f.Room.ID) - dst.ClientID = pickString(dst.ClientID, f.Room.ClientID) dst.KeyHex = pickString(dst.KeyHex, f.Crypto.Key) dst.SOCKSHost = pickString(dst.SOCKSHost, f.SOCKS.Host) dst.SOCKSPort = pickInt(dst.SOCKSPort, f.SOCKS.Port) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9c54d72..6c402b2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -18,7 +18,6 @@ auth: provider: wbstream room: id: r1 - client_id: c1 crypto: key: deadbeef net: @@ -50,7 +49,7 @@ debug: true got := Apply(session.Config{}, f) if got.Mode != "srv" || got.Link != "direct" || got.Auth != "wbstream" || - got.RoomID != "r1" || got.ClientID != "c1" || got.KeyHex != "deadbeef" || + got.RoomID != "r1" || got.KeyHex != "deadbeef" || got.Transport != "datachannel" || got.DNSServer != "1.1.1.1:53" || got.SOCKSHost != "127.0.0.1" || got.SOCKSPort != 1080 || got.SOCKSUser != "u" || got.SOCKSPass != "p" || diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index dfb036f..b2aad1b 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -400,7 +400,6 @@ func validSessionConfig(mode, carrierName, transportName string) session.Config Transport: transportName, Auth: carrierName, RoomID: "room", - ClientID: "client-1", KeyHex: testKeyHex, SOCKSHost: "127.0.0.1", SOCKSPort: 1080, @@ -428,7 +427,7 @@ func validLinkConfig(carrierName, transportName string) link.Config { Transport: cfg.Transport, Carrier: cfg.Auth, RoomURL: "room", - ClientID: cfg.ClientID, + DeviceID: "e2e-link-test", Name: "e2e-" + carrierName + "-" + transportName, DNSServer: cfg.DNSServer, VideoWidth: cfg.VideoWidth, @@ -505,7 +504,7 @@ type tunnelRuntime struct { clientErr chan error } -func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRuntime { +func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime { t.Helper() carrierName, room := registerMemoryCarrier(t) @@ -521,7 +520,6 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun Carrier: carrierName, RoomURL: "room", KeyHex: testKeyHex, - ClientID: serverClientID, DNSServer: "127.0.0.1:53", }) }() @@ -536,7 +534,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun Carrier: carrierName, RoomURL: "room", KeyHex: testKeyHex, - ClientID: clientClientID, + DeviceID: deviceID, LocalAddr: socksAddr, DNSServer: "127.0.0.1:53", }, func() { close(ready) }) @@ -555,7 +553,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun func startRealTunnel( ctx context.Context, t *testing.T, - carrierName, transportName, roomURL, serverClientID, clientClientID string, + carrierName, transportName, roomURL, _, clientDeviceID string, ) (*tunnelRuntime, error) { t.Helper() @@ -573,7 +571,6 @@ func startRealTunnel( Carrier: carrierName, RoomURL: roomURL, KeyHex: testKeyHex, - ClientID: serverClientID, DNSServer: "127.0.0.1:53", VideoWidth: 1080, VideoHeight: 1080, @@ -613,7 +610,7 @@ func startRealTunnel( Carrier: carrierName, RoomURL: roomURL, KeyHex: testKeyHex, - ClientID: clientClientID, + DeviceID: clientDeviceID, LocalAddr: socksAddr, DNSServer: "127.0.0.1:53", VideoWidth: 1080, @@ -749,49 +746,6 @@ func connectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn { return conn } -func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) []byte { - t.Helper() - - dialer := net.Dialer{Timeout: 2 * time.Second} - conn, err := dialer.DialContext(context.Background(), "tcp4", socksAddr) - if err != nil { - t.Fatalf("dial socks: %v", err) - } - defer func() { _ = conn.Close() }() - - if _, err := conn.Write([]byte{5, 1, 0}); err != nil { - t.Fatalf("write socks greeting: %v", err) - } - greeting := make([]byte, 2) - if _, err := io.ReadFull(conn, greeting); err != nil { - t.Fatalf("read socks greeting: %v", err) - } - - host, portText, err := net.SplitHostPort(targetAddr) - if err != nil { - t.Fatalf("split target addr: %v", err) - } - port, err := strconv.Atoi(portText) - if err != nil { - t.Fatalf("parse target port: %v", err) - } - req := make([]byte, 0, 10) - req = append(req, 5, 1, 0, 1) - req = append(req, net.ParseIP(host).To4()...) - var portBuf [2]byte - binary.BigEndian.PutUint16(portBuf[:], uint16(port)) //nolint:gosec // SOCKS5 port is uint16 by definition - req = append(req, portBuf[:]...) - if _, err := conn.Write(req); err != nil { - t.Fatalf("write socks connect: %v", err) - } - - reply := make([]byte, 10) - if _, err := io.ReadFull(conn, reply); err != nil { - t.Fatalf("read socks failure reply: %v", err) - } - return reply -} - func TestBuiltInProviderTransportMatrixValidates(t *testing.T) { session.RegisterDefaults() @@ -971,17 +925,6 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { } } -func TestWrongClientIDIsRejected(t *testing.T) { - echoAddr := startEchoServer(t) - rt := startTunnel(t, "server-client", "wrong-client") - defer rt.stop(t) - - reply := connectViaSOCKSExpectFailure(t, rt.socksAddr, echoAddr) - if !bytes.Equal(reply, []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) { - t.Fatalf("wrong client-id reply = %v, want host unreachable", reply) - } -} - func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { echoAddr := startEchoServer(t) rt := startTunnel(t, "client-1", "client-1") diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go new file mode 100644 index 0000000..9d66f15 --- /dev/null +++ b/internal/handshake/handshake.go @@ -0,0 +1,214 @@ +// Package handshake implements the olcrtc session handshake. +// +// The handshake runs on the first smux stream (control stream) of a tunnel. +// Wire format on the control stream is length-prefixed JSON: each message is +// a 4-byte big-endian length followed by that many bytes of JSON. +// +// client server +// │ CLIENT_HELLO │ +// │ ─────────────────────► │ +// │ │ AuthHook(claims) → sessionID | err +// │ SERVER_WELCOME / REJECT│ +// │ ◄───────────────────── │ +// │ │ +// +// After the exchange the control stream stays open; tunnel traffic flows over +// additional smux streams opened by the client. The control stream may carry +// keepalives or future control messages. +package handshake + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "time" +) + +// ProtoVersion identifies the wire-format version. Bumped only on breaking +// changes to message layout or semantics. +const ProtoVersion = 1 + +// MaxMessageSize caps a single handshake frame. 64 KiB is comfortably larger +// than any legitimate HELLO/WELCOME payload and prevents memory blowups from +// malicious peers. +const MaxMessageSize = 64 * 1024 + +// DefaultTimeout bounds how long either side will wait for the peer's reply +// before bailing out. +const DefaultTimeout = 15 * time.Second + +// MsgType labels each protocol message. +type MsgType string + +const ( + // TypeHello is the client's first message. + TypeHello MsgType = "CLIENT_HELLO" + // TypeWelcome is the server's success reply. + TypeWelcome MsgType = "SERVER_WELCOME" + // TypeReject is the server's failure reply. + TypeReject MsgType = "SERVER_REJECT" +) + +// Hello is sent by the client to begin a session. +type Hello struct { + Version int `json:"version"` + Type MsgType `json:"type"` + DeviceID string `json:"device_id"` + Claims map[string]any `json:"claims,omitempty"` +} + +// Welcome is the server's response on a successful handshake. +type Welcome struct { + Version int `json:"version"` + Type MsgType `json:"type"` + SessionID string `json:"session_id"` +} + +// Reject is the server's response when auth fails. +type Reject struct { + Version int `json:"version"` + Type MsgType `json:"type"` + Reason string `json:"reason"` +} + +// Errors returned by [Client] and [Server]. +var ( + // ErrRejected wraps a server-side rejection. The reason is in the error message. + ErrRejected = errors.New("handshake rejected") + // ErrProtocolVersion is returned when peer announces an incompatible version. + ErrProtocolVersion = errors.New("incompatible protocol version") + // ErrUnexpectedMessage is returned when a peer sends the wrong message type. + ErrUnexpectedMessage = errors.New("unexpected handshake message") + // ErrFrameTooLarge is returned when a peer announces a frame above [MaxMessageSize]. + ErrFrameTooLarge = errors.New("handshake frame too large") +) + +// AuthFunc is invoked by [Server] after parsing CLIENT_HELLO. +// It returns the session ID to send back to the client, or an error to reject +// the connection. The error's message is forwarded to the client as the +// reject reason, so it should not leak sensitive details. +type AuthFunc func(deviceID string, claims map[string]any) (sessionID string, err error) + +// Client performs the client side of the handshake on rw and returns the +// session ID assigned by the server. +func Client(rw io.ReadWriter, deviceID string, claims map[string]any) (string, error) { + hello := Hello{ + Version: ProtoVersion, + Type: TypeHello, + DeviceID: deviceID, + Claims: claims, + } + if err := writeFrame(rw, hello); err != nil { + return "", fmt.Errorf("send hello: %w", err) + } + + raw, err := readFrame(rw) + if err != nil { + return "", fmt.Errorf("read welcome: %w", err) + } + + var probe struct { + Type MsgType `json:"type"` + } + if err := json.Unmarshal(raw, &probe); err != nil { + return "", fmt.Errorf("parse reply: %w", err) + } + + switch probe.Type { + case TypeWelcome: + var w Welcome + if err := json.Unmarshal(raw, &w); err != nil { + return "", fmt.Errorf("parse welcome: %w", err) + } + if w.Version != ProtoVersion { + return "", fmt.Errorf("%w: server v%d, client v%d", + ErrProtocolVersion, w.Version, ProtoVersion) + } + return w.SessionID, nil + case TypeReject: + var r Reject + if err := json.Unmarshal(raw, &r); err != nil { + return "", fmt.Errorf("parse reject: %w", err) + } + return "", fmt.Errorf("%w: %s", ErrRejected, r.Reason) + default: + return "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, probe.Type) + } +} + +// Server performs the server side of the handshake. It reads CLIENT_HELLO, +// invokes auth, and writes the corresponding WELCOME or REJECT. On success it +// returns the parsed Hello and the session ID produced by auth. +func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) { + raw, err := readFrame(rw) + if err != nil { + return Hello{}, "", fmt.Errorf("read hello: %w", err) + } + + var h Hello + if err := json.Unmarshal(raw, &h); err != nil { + _ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "malformed hello"}) + return Hello{}, "", fmt.Errorf("parse hello: %w", err) + } + if h.Type != TypeHello { + _ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "expected CLIENT_HELLO"}) + return h, "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, h.Type) + } + if h.Version != ProtoVersion { + _ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "protocol version mismatch"}) + return h, "", fmt.Errorf("%w: client v%d, server v%d", + ErrProtocolVersion, h.Version, ProtoVersion) + } + + sessionID, err := auth(h.DeviceID, h.Claims) + if err != nil { + _ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: err.Error()}) + return h, "", fmt.Errorf("auth: %w", err) + } + + if err := writeFrame(rw, Welcome{ + Version: ProtoVersion, + Type: TypeWelcome, + SessionID: sessionID, + }); err != nil { + return h, sessionID, fmt.Errorf("send welcome: %w", err) + } + return h, sessionID, nil +} + +func writeFrame(w io.Writer, msg any) error { + body, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + if len(body) > MaxMessageSize { + return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize) + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize + if _, err := w.Write(hdr[:]); err != nil { + return fmt.Errorf("write hdr: %w", err) + } + if _, err := w.Write(body); err != nil { + return fmt.Errorf("write body: %w", err) + } + return nil +} + +func readFrame(r io.Reader) ([]byte, error) { + var hdr [4]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("read hdr: %w", err) + } + n := binary.BigEndian.Uint32(hdr[:]) + if n > MaxMessageSize { + return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + return buf, nil +} diff --git a/internal/handshake/handshake_test.go b/internal/handshake/handshake_test.go new file mode 100644 index 0000000..790192b --- /dev/null +++ b/internal/handshake/handshake_test.go @@ -0,0 +1,128 @@ +package handshake + +import ( + "errors" + "io" + "net" + "strings" + "testing" +) + +func pair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + a, b := net.Pipe() + t.Cleanup(func() { + _ = a.Close() + _ = b.Close() + }) + return a, b +} + +func TestHandshakeRoundTrip(t *testing.T) { + cConn, sConn := pair(t) + + go func() { + hello, sid, err := Server(sConn, func(deviceID string, claims map[string]any) (string, error) { + if deviceID != "dev-1" { + t.Errorf("device id = %q", deviceID) + } + if claims["plan"] != "pro" { + t.Errorf("claims = %v", claims) + } + return "sess-42", nil + }) + if err != nil { + t.Errorf("Server: %v", err) + } + if hello.DeviceID != "dev-1" || sid != "sess-42" { + t.Errorf("Server returned hello=%+v sid=%q", hello, sid) + } + }() + + sid, err := Client(cConn, "dev-1", map[string]any{"plan": "pro"}) + if err != nil { + t.Fatalf("Client: %v", err) + } + if sid != "sess-42" { + t.Fatalf("session id = %q, want sess-42", sid) + } +} + +func TestHandshakeRejected(t *testing.T) { + cConn, sConn := pair(t) + + go func() { + _, _, _ = Server(sConn, func(string, map[string]any) (string, error) { + return "", errors.New("nope") + }) + }() + + _, err := Client(cConn, "dev-1", nil) + if !errors.Is(err, ErrRejected) { + t.Fatalf("Client err = %v, want ErrRejected", err) + } + if !strings.Contains(err.Error(), "nope") { + t.Fatalf("err message %q missing reason", err.Error()) + } +} + +func TestHandshakeProtocolMismatch(t *testing.T) { + cConn, sConn := pair(t) + + go func() { + _ = writeFrame(cConn, Hello{Version: 999, Type: TypeHello, DeviceID: "dev"}) + _, _ = readFrame(cConn) // drain server's REJECT so its write does not block + }() + + _, _, err := Server(sConn, func(string, map[string]any) (string, error) { + t.Fatal("auth must not be invoked on protocol mismatch") + return "", nil + }) + if !errors.Is(err, ErrProtocolVersion) { + t.Fatalf("Server err = %v, want ErrProtocolVersion", err) + } +} + +func TestHandshakeUnexpectedType(t *testing.T) { + cConn, sConn := pair(t) + + go func() { + _ = writeFrame(cConn, Hello{Version: ProtoVersion, Type: "BOGUS", DeviceID: "dev"}) + _, _ = readFrame(cConn) // drain server's REJECT + }() + + _, _, err := Server(sConn, func(string, map[string]any) (string, error) { + t.Fatal("auth must not be invoked on bad type") + return "", nil + }) + if !errors.Is(err, ErrUnexpectedMessage) { + t.Fatalf("Server err = %v, want ErrUnexpectedMessage", err) + } +} + +func TestReadFrameTooLarge(t *testing.T) { + cConn, sConn := pair(t) + + go func() { + var hdr [4]byte + hdr[0] = 0xff + hdr[1] = 0xff + _, _ = cConn.Write(hdr[:]) + _ = cConn.Close() + }() + + _, err := readFrame(sConn) + if !errors.Is(err, ErrFrameTooLarge) { + t.Fatalf("readFrame err = %v, want ErrFrameTooLarge", err) + } +} + +func TestReadFrameEOF(t *testing.T) { + cConn, sConn := pair(t) + _ = cConn.Close() + + _, err := readFrame(sConn) + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("readFrame err = %v", err) + } +} diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go index 26b44fe..4b2aa73 100644 --- a/internal/link/direct/direct.go +++ b/internal/link/direct/direct.go @@ -21,7 +21,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) { Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, - ClientID: cfg.ClientID, + DeviceID: cfg.DeviceID, Name: cfg.Name, OnData: cfg.OnData, DNSServer: cfg.DNSServer, diff --git a/internal/link/direct/direct_test.go b/internal/link/direct/direct_test.go index bc1f3f0..18edd2e 100644 --- a/internal/link/direct/direct_test.go +++ b/internal/link/direct/direct_test.go @@ -62,7 +62,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { Transport: name, Carrier: "carrier", RoomURL: "room", - ClientID: "client", + DeviceID: "client", Name: "peer", DNSServer: "1.1.1.1:53", ProxyAddr: "127.0.0.1", @@ -84,7 +84,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { t.Fatalf("New() error = %v", err) } - if seen.ClientID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 { + if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 { t.Fatalf("forwarded config = %+v", seen) } diff --git a/internal/link/link.go b/internal/link/link.go index 9989e51..f094cd0 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -32,7 +32,7 @@ type Config struct { Engine string URL string Token string - ClientID string + DeviceID string Name string OnData func([]byte) DNSServer string diff --git a/internal/link/link_test.go b/internal/link/link_test.go index b53dd38..15260cc 100644 --- a/internal/link/link_test.go +++ b/internal/link/link_test.go @@ -39,11 +39,11 @@ func TestNewAndAvailable(t *testing.T) { called := false Register("test-link", func(_ context.Context, cfg Config) (Link, error) { - called = cfg.ClientID == "client-1" + called = cfg.DeviceID == "client-1" return &stubLink{}, nil }) - got, err := New(context.Background(), "test-link", Config{ClientID: "client-1"}) + got, err := New(context.Background(), "test-link", Config{DeviceID: "client-1"}) if err != nil { t.Fatalf("New() error = %v", err) } diff --git a/internal/server/server.go b/internal/server/server.go index dcf6579..af20c49 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,7 +13,9 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" @@ -43,7 +45,9 @@ type Server struct { sessMu sync.RWMutex reinstallMu sync.Mutex wg sync.WaitGroup - clientID string + authHook handshake.AuthFunc + deviceID string + sessionID string dnsServer string resolver *net.Resolver socksProxyAddr string @@ -52,10 +56,9 @@ type Server struct { // ConnectRequest is a message from the client to establish a new connection. type ConnectRequest struct { - Cmd string `json:"cmd"` - ClientID string `json:"clientId"` - Addr string `json:"addr"` - Port int `json:"port"` + Cmd string `json:"cmd"` + Addr string `json:"addr"` + Port int `json:"port"` } // Config holds runtime configuration for [Run]. @@ -65,7 +68,6 @@ type Config struct { Carrier string RoomURL string KeyHex string - ClientID string DNSServer string SOCKSProxyAddr string SOCKSProxyPort int @@ -88,6 +90,10 @@ type Config struct { Engine string URL string Token string + + // AuthHook is invoked after CLIENT_HELLO to authorize the client and + // return a session ID. If nil, every client is admitted with a random UUID. + AuthHook handshake.AuthFunc } // Run starts the server with the given configuration. @@ -100,9 +106,14 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("setupCipher failed: %w", err) } + hook := cfg.AuthHook + if hook == nil { + hook = defaultAuthHook + } + s := &Server{ cipher: cipher, - clientID: cfg.ClientID, + authHook: hook, dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, @@ -182,7 +193,7 @@ func (s *Server) bringUpLink( Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, - ClientID: s.clientID, + DeviceID: "", Name: names.Generate(), OnData: s.onData, DNSServer: s.dnsServer, @@ -270,6 +281,8 @@ func (s *Server) reinstallSession(dead *smux.Session) { _ = s.conn.Close() s.conn = nil } + s.sessionID = "" + s.deviceID = "" s.sessMu.Unlock() s.installSession() } @@ -284,6 +297,8 @@ func (s *Server) closeSession() { _ = s.conn.Close() s.conn = nil } + s.sessionID = "" + s.deviceID = "" s.sessMu.Unlock() } @@ -296,9 +311,9 @@ func (s *Server) onData(data []byte) { } } -// serve drives the smux Accept loop, spawning a tunnel per inbound stream. -// The loop tolerates session bounces (reconnects) by waiting until a fresh -// session is installed instead of terminating the server. +// serve drives the smux Accept loop. The first accepted stream on a given +// smux session is the control stream — the handshake runs there. Subsequent +// streams are tunnel streams and proxy traffic. func (s *Server) serve(ctx context.Context) { for { select { @@ -319,6 +334,12 @@ func (s *Server) serve(ctx context.Context) { } } + if !s.handshakeReady() { + if !s.acceptHandshake(ctx, sess) { + continue + } + } + stream, err := sess.AcceptStream() if err != nil { select { @@ -339,6 +360,62 @@ func (s *Server) serve(ctx context.Context) { } } +// handshakeReady reports whether the current session has completed its +// handshake. The session is reset on reconnect, so this is recomputed. +func (s *Server) handshakeReady() bool { + s.sessMu.RLock() + defer s.sessMu.RUnlock() + return s.sessionID != "" +} + +func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { + stream, err := sess.AcceptStream() + if err != nil { + select { + case <-ctx.Done(): + return false + default: + } + logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err) + s.reinstallSession(sess) + return false + } + _ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout)) + hello, sid, err := handshake.Server(stream, s.authHook) + _ = stream.SetDeadline(time.Time{}) + if err != nil { + logger.Warnf("handshake failed: %v", err) + _ = stream.Close() + s.reinstallSession(sess) + return false + } + s.sessMu.Lock() + s.deviceID = hello.DeviceID + s.sessionID = sid + s.sessMu.Unlock() + logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID) + // The control stream stays open for the lifetime of the session; + // keep it parked in a goroutine so the smux session does not close it. + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.parkControlStream(stream) + }() + return true +} + +// parkControlStream blocks reading from the control stream until it closes. +// Future control messages (kick, rate updates, etc.) would be dispatched here. +func (s *Server) parkControlStream(stream *smux.Stream) { + defer func() { _ = stream.Close() }() + buf := make([]byte, 64) + for { + if _, err := stream.Read(buf); err != nil { + return + } + } +} + func (s *Server) shutdown() { s.closeSession() if s.ln != nil { @@ -362,10 +439,6 @@ func (s *Server) handleStream(_ context.Context, stream *smux.Stream) { header = append(header, tmp[:n]...) if req, ok := parseConnectRequest(header); ok { _ = stream.SetReadDeadline(time.Time{}) - if !s.authorizeRequest(req) { - logger.Warnf("sid=%d rejected: client_id mismatch", stream.ID()) - return - } s.dispatch(stream, req) return } @@ -390,8 +463,10 @@ func parseConnectRequest(buf []byte) (ConnectRequest, bool) { return req, true } -func (s *Server) authorizeRequest(req ConnectRequest) bool { - return req.ClientID == s.clientID +// defaultAuthHook admits every client and assigns a random session ID. +// Replace it via [Config.AuthHook] to plug in real authorization. +func defaultAuthHook(_ string, _ map[string]any) (string, error) { + return uuid.NewString(), nil } func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 26dbd67..1414c68 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -47,10 +47,9 @@ func TestSmuxConfig(t *testing.T) { func TestParseConnectRequest(t *testing.T) { buf, err := json.Marshal(ConnectRequest{ - Cmd: "connect", - ClientID: "client-1", //nolint:goconst // test literal, repetition is intentional - Addr: "example.com", //nolint:goconst // test literal, repetition is intentional - Port: 443, + Cmd: "connect", + Addr: "example.com", //nolint:goconst // test literal, repetition is intentional + Port: 443, }) if err != nil { t.Fatalf("Marshal() error = %v", err) @@ -60,7 +59,7 @@ func TestParseConnectRequest(t *testing.T) { if !ok { t.Fatal("parseConnectRequest() returned ok=false") } - if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 { + if req.Addr != "example.com" || req.Port != 443 { t.Fatalf("parseConnectRequest() = %+v", req) } @@ -72,13 +71,13 @@ func TestParseConnectRequest(t *testing.T) { } } -func TestAuthorizeRequest(t *testing.T) { - s := &Server{clientID: "client-1"} - if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) { - t.Fatal("authorizeRequest() rejected valid client") +func TestDefaultAuthHook(t *testing.T) { + sid, err := defaultAuthHook("dev", map[string]any{"x": 1}) + if err != nil { + t.Fatalf("defaultAuthHook() err = %v", err) } - if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) { - t.Fatal("authorizeRequest() accepted wrong client") + if sid == "" { + t.Fatal("defaultAuthHook() returned empty session id") } } @@ -301,7 +300,7 @@ func TestSocks5ConnectTruncatesLongDomain(t *testing.T) { } } -func TestHandleStreamRejectsWrongClientID(t *testing.T) { +func TestHandleStreamDispatchAfterConnect(t *testing.T) { a, b := net.Pipe() defer func() { _ = a.Close() @@ -323,7 +322,7 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) { go func() { stream, err := serverSess.AcceptStream() if err == nil { - (&Server{clientID: "expected"}).handleStream(context.Background(), stream) + (&Server{}).handleStream(context.Background(), stream) } close(done) }() @@ -333,10 +332,9 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) { t.Fatalf("OpenStream() error = %v", err) } req, err := json.Marshal(ConnectRequest{ - Cmd: "connect", - ClientID: "wrong", - Addr: "example.com", - Port: 443, + Cmd: "connect", + Addr: "127.0.0.1", + Port: 1, // unreachable port — dispatch will fail dial and exit }) if err != nil { t.Fatalf("Marshal() error = %v", err) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 90153a2..9e11240 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -41,7 +41,7 @@ type Config struct { Engine string URL string Token string - ClientID string + DeviceID string Name string OnData func([]byte) DNSServer string diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 6330b6a..dfa2683 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -40,11 +40,11 @@ func TestNewAndAvailable(t *testing.T) { called := false Register("test-transport", func(_ context.Context, cfg Config) (Transport, error) { - called = cfg.ClientID == "client-1" + called = cfg.DeviceID == "client-1" return &stubTransport{}, nil }) - got, err := New(context.Background(), "test-transport", Config{ClientID: "client-1"}) + got, err := New(context.Background(), "test-transport", Config{DeviceID: "client-1"}) if err != nil { t.Fatalf("New() error = %v", err) } diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index 13875b3..d46cc73 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -162,7 +162,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) writerDone: make(chan struct{}), frameInterval: time.Second / time.Duration(fps), batchSize: batchSize, - bindingToken: bindingToken(cfg.ClientID), + bindingToken: bindingToken(cfg.DeviceID), localEpoch: randomEpoch(), } diff --git a/internal/transport/vp8channel/transport_unit_test.go b/internal/transport/vp8channel/transport_unit_test.go index bc506c5..e40d86e 100644 --- a/internal/transport/vp8channel/transport_unit_test.go +++ b/internal/transport/vp8channel/transport_unit_test.go @@ -92,7 +92,7 @@ func TestNewConnectSendCallbacksFeaturesAndClose(t *testing.T) { trIface, err := New(context.Background(), transport.Config{ Carrier: name, - ClientID: "client", + DeviceID: "client", VP8FPS: 30, VP8BatchSize: 1, }) diff --git a/mobile/mobile.go b/mobile/mobile.go index c1e3798..0cf1a55 100644 --- a/mobile/mobile.go +++ b/mobile/mobile.go @@ -222,7 +222,7 @@ func Check( Carrier: carrierName, RoomURL: buildRoomURL(carrierName, roomID), KeyHex: keyHex, - ClientID: clientID, + DeviceID: clientID, LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort), DNSServer: defaultDNSServer, VP8FPS: clampAtLeastOne(vp8FPS, 120), @@ -305,7 +305,7 @@ func Ping( Carrier: carrierName, RoomURL: buildRoomURL(carrierName, roomID), KeyHex: keyHex, - ClientID: clientID, + DeviceID: clientID, LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort), DNSServer: defaultDNSServer, VP8FPS: clampAtLeastOne(vp8FPS, 120), @@ -550,7 +550,7 @@ func startWithConfig( Carrier: carrierName, RoomURL: roomURL, KeyHex: keyHex, - ClientID: clientID, + DeviceID: clientID, LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort), DNSServer: cfg.dnsServer, SOCKSUser: socksUser, diff --git a/mobile/mobile_test.go b/mobile/mobile_test.go index 8b635c1..541fba5 100644 --- a/mobile/mobile_test.go +++ b/mobile/mobile_test.go @@ -171,10 +171,10 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) { runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error { if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz || - cfg.RoomURL != "any" || cfg.ClientID != "client" || cfg.LocalAddr != "127.0.0.1:1080" || + cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" || cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 { t.Fatalf("RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d", - cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.ClientID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize) + cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize) } onReady() <-ctx.Done()