feat: remove unused client ID from config

This commit is contained in:
zarazaex69
2026-05-13 20:03:58 +03:00
parent d1cc68c64a
commit bcc6b2ee5c
22 changed files with 600 additions and 156 deletions

View File

@@ -83,7 +83,6 @@ func TestRunWithConfigValidationAndDataDirErrors(t *testing.T) {
Link: "direct",
Transport: "datachannel",
Auth: "jazz",
ClientID: "client",
KeyHex: "key",
DNSServer: "1.1.1.1:53",
}
@@ -113,7 +112,7 @@ func TestRunWithArgsSuccessfulSessionReturn(t *testing.T) {
called := false
runSession = func(ctx context.Context, cfg session.Config) error {
called = true
if cfg.Mode != "srv" || cfg.Auth != "jazz" || cfg.ClientID != "client" {
if cfg.Mode != "srv" || cfg.Auth != "jazz" {
t.Fatalf("session config = %+v", cfg)
}
select {
@@ -129,8 +128,6 @@ mode: srv
link: direct
auth:
provider: jazz
room:
client_id: client
crypto:
key: key
net:

View File

@@ -101,8 +101,6 @@ var (
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
// ErrClientIDRequired indicates that client ID is required.
ErrClientIDRequired = errors.New("client ID required (use -client-id <id>)")
)
// Config holds runtime session settings.
@@ -115,7 +113,6 @@ type Config struct {
URL string
Token string
RoomID string
ClientID string
KeyHex string
SOCKSHost string
SOCKSPort int
@@ -242,9 +239,6 @@ func validateCommon(cfg Config) error {
if cfg.RoomID == "" && cfg.Auth != authJazz && cfg.Auth != authNone {
return ErrRoomIDRequired
}
if cfg.ClientID == "" {
return ErrClientIDRequired
}
if cfg.KeyHex == "" {
return ErrKeyRequired
}

View File

@@ -16,7 +16,6 @@ func TestValidate(t *testing.T) {
Transport: "datachannel",
Auth: "telemost",
RoomID: "room-1",
ClientID: "client-1",
KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
DNSServer: "1.1.1.1:53", //nolint:goconst // test literal, repetition is intentional
}
@@ -91,15 +90,6 @@ func TestValidate(t *testing.T) {
}(),
want: ErrRoomIDRequired,
},
{
name: "client id required",
cfg: func() Config {
cfg := base
cfg.ClientID = ""
return cfg
}(),
want: ErrClientIDRequired,
},
{
name: "key required",
cfg: func() Config {

View File

@@ -10,10 +10,15 @@ import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
@@ -44,15 +49,18 @@ var (
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
sessMu sync.RWMutex
clientID string
dnsServer string
socksUser string
socksPass string
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
sessMu sync.RWMutex
deviceID string
sessionID string
claims map[string]any
dnsServer string
socksUser string
socksPass string
}
// Config holds runtime configuration for [Run] and [RunWithReady].
@@ -62,7 +70,6 @@ type Config struct {
Carrier string
RoomURL string
KeyHex string
ClientID string
LocalAddr string
DNSServer string
SOCKSUser string
@@ -86,6 +93,19 @@ type Config struct {
Engine string
URL string
Token string
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
// are empty).
DeviceID string
// DeviceIDPath is a file in which to persist the auto-generated device ID
// across restarts. Ignored when DeviceID is set explicitly.
DeviceIDPath string
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
}
// Run starts the client with the given configuration.
@@ -103,9 +123,15 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
return fmt.Errorf("setupCipher failed: %w", err)
}
deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath)
if err != nil {
return fmt.Errorf("resolve device id: %w", err)
}
c := &Client{
cipher: cipher,
clientID: cfg.ClientID,
deviceID: deviceID,
claims: cfg.Claims,
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
@@ -147,7 +173,7 @@ func (c *Client) bringUpLink(
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
ClientID: c.clientID,
DeviceID: c.deviceID,
Name: names.Generate(),
OnData: c.onData,
DNSServer: cfg.DNSServer,
@@ -188,14 +214,80 @@ func (c *Client) bringUpLink(
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if err != nil {
_ = sess.Close()
_ = c.conn.Close()
return fmt.Errorf("handshake: %w", err)
}
logger.Infof("session %s opened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session — the server
// holds it parked, and it would carry future control messages.
func openControlStream(
sess *smux.Session,
deviceID string,
claims map[string]any,
) (*smux.Stream, string, error) {
stream, err := sess.OpenStream()
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
_ = stream.Close()
return nil, "", err
}
return stream, sid, nil
}
// resolveDeviceID returns the device ID to send in CLIENT_HELLO.
//
// Precedence:
// 1. Explicit deviceID arg (Config.DeviceID) — used verbatim.
// 2. Persistent file at path (Config.DeviceIDPath) — read if it exists,
// otherwise generated and written for future runs.
// 3. Random UUID per run when both inputs are empty.
func resolveDeviceID(deviceID, path string) (string, error) {
if deviceID != "" {
return deviceID, nil
}
if path == "" {
return uuid.NewString(), nil
}
data, err := os.ReadFile(path)
if err == nil {
id := strings.TrimSpace(string(data))
if id != "" {
return id, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", fmt.Errorf("read device id %s: %w", path, err)
}
id := uuid.NewString()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", fmt.Errorf("mkdir device id dir: %w", err)
}
if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil {
return "", fmt.Errorf("write device id %s: %w", path, err)
}
return id, nil
}
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
@@ -212,6 +304,10 @@ func smuxConfig() *smux.Config {
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect - tearing down smux session")
c.sessMu.Lock()
if c.controlStrm != nil {
_ = c.controlStrm.Close()
c.controlStrm = nil
}
if c.session != nil {
_ = c.session.Close()
c.session = nil
@@ -220,6 +316,7 @@ func (c *Client) handleReconnect() {
_ = c.conn.Close()
c.conn = nil
}
c.sessionID = ""
c.sessMu.Unlock()
c.conn = muxconn.New(c.ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
@@ -227,13 +324,25 @@ func (c *Client) handleReconnect() {
logger.Warnf("smux re-init failed: %v", err)
return
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if err != nil {
logger.Warnf("handshake on reconnect failed: %v", err)
_ = sess.Close()
return
}
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
}
func (c *Client) shutdown() {
c.sessMu.Lock()
if c.controlStrm != nil {
_ = c.controlStrm.Close()
}
if c.session != nil {
_ = c.session.Close()
}
@@ -340,10 +449,9 @@ func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, ta
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
connectReq, err := json.Marshal(map[string]any{
"cmd": "connect",
"clientId": c.clientID,
"addr": targetAddr,
"port": targetPort,
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err != nil {
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)

View File

@@ -417,7 +417,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) {
done <- err
return
}
if req["cmd"] != "connect" || req["clientId"] != "client-1" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional
if req["cmd"] != "connect" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional
done <- errUnexpectedConnectRequest
return
}
@@ -431,7 +431,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) {
}
defer func() { _ = stream.Close() }()
c := &Client{clientID: "client-1"}
c := &Client{deviceID: "client-1"}
if err := c.sendConnectRequest(stream, "example.com", 443); err != nil {
t.Fatalf("sendConnectRequest() error = %v", err)
}
@@ -473,7 +473,7 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) {
}
defer func() { _ = stream.Close() }()
c := &Client{clientID: "client-1"}
c := &Client{deviceID: "client-1"}
if err := c.sendConnectRequest(stream, "example.com", 443); !errors.Is(err, ErrRemoteNotReady) {
t.Fatalf("sendConnectRequest() error = %v, want %v", err, ErrRemoteNotReady)
}

View File

@@ -45,8 +45,7 @@ type Auth struct {
// Room identifies the conference room.
type Room struct {
ID string `yaml:"id"`
ClientID string `yaml:"client_id"` // deprecated: server identifier (will be removed)
ID string `yaml:"id"`
}
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
@@ -137,7 +136,6 @@ func Apply(dst session.Config, f File) session.Config {
dst.URL = pickString(dst.URL, f.Engine.URL)
dst.Token = pickString(dst.Token, f.Engine.Token)
dst.RoomID = pickString(dst.RoomID, f.Room.ID)
dst.ClientID = pickString(dst.ClientID, f.Room.ClientID)
dst.KeyHex = pickString(dst.KeyHex, f.Crypto.Key)
dst.SOCKSHost = pickString(dst.SOCKSHost, f.SOCKS.Host)
dst.SOCKSPort = pickInt(dst.SOCKSPort, f.SOCKS.Port)

View File

@@ -18,7 +18,6 @@ auth:
provider: wbstream
room:
id: r1
client_id: c1
crypto:
key: deadbeef
net:
@@ -50,7 +49,7 @@ debug: true
got := Apply(session.Config{}, f)
if got.Mode != "srv" || got.Link != "direct" || got.Auth != "wbstream" ||
got.RoomID != "r1" || got.ClientID != "c1" || got.KeyHex != "deadbeef" ||
got.RoomID != "r1" || got.KeyHex != "deadbeef" ||
got.Transport != "datachannel" || got.DNSServer != "1.1.1.1:53" ||
got.SOCKSHost != "127.0.0.1" || got.SOCKSPort != 1080 ||
got.SOCKSUser != "u" || got.SOCKSPass != "p" ||

View File

@@ -400,7 +400,6 @@ func validSessionConfig(mode, carrierName, transportName string) session.Config
Transport: transportName,
Auth: carrierName,
RoomID: "room",
ClientID: "client-1",
KeyHex: testKeyHex,
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
@@ -428,7 +427,7 @@ func validLinkConfig(carrierName, transportName string) link.Config {
Transport: cfg.Transport,
Carrier: cfg.Auth,
RoomURL: "room",
ClientID: cfg.ClientID,
DeviceID: "e2e-link-test",
Name: "e2e-" + carrierName + "-" + transportName,
DNSServer: cfg.DNSServer,
VideoWidth: cfg.VideoWidth,
@@ -505,7 +504,7 @@ type tunnelRuntime struct {
clientErr chan error
}
func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRuntime {
func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime {
t.Helper()
carrierName, room := registerMemoryCarrier(t)
@@ -521,7 +520,6 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
Carrier: carrierName,
RoomURL: "room",
KeyHex: testKeyHex,
ClientID: serverClientID,
DNSServer: "127.0.0.1:53",
})
}()
@@ -536,7 +534,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
Carrier: carrierName,
RoomURL: "room",
KeyHex: testKeyHex,
ClientID: clientClientID,
DeviceID: deviceID,
LocalAddr: socksAddr,
DNSServer: "127.0.0.1:53",
}, func() { close(ready) })
@@ -555,7 +553,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
func startRealTunnel(
ctx context.Context,
t *testing.T,
carrierName, transportName, roomURL, serverClientID, clientClientID string,
carrierName, transportName, roomURL, _, clientDeviceID string,
) (*tunnelRuntime, error) {
t.Helper()
@@ -573,7 +571,6 @@ func startRealTunnel(
Carrier: carrierName,
RoomURL: roomURL,
KeyHex: testKeyHex,
ClientID: serverClientID,
DNSServer: "127.0.0.1:53",
VideoWidth: 1080,
VideoHeight: 1080,
@@ -613,7 +610,7 @@ func startRealTunnel(
Carrier: carrierName,
RoomURL: roomURL,
KeyHex: testKeyHex,
ClientID: clientClientID,
DeviceID: clientDeviceID,
LocalAddr: socksAddr,
DNSServer: "127.0.0.1:53",
VideoWidth: 1080,
@@ -749,49 +746,6 @@ func connectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn {
return conn
}
func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) []byte {
t.Helper()
dialer := net.Dialer{Timeout: 2 * time.Second}
conn, err := dialer.DialContext(context.Background(), "tcp4", socksAddr)
if err != nil {
t.Fatalf("dial socks: %v", err)
}
defer func() { _ = conn.Close() }()
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
t.Fatalf("write socks greeting: %v", err)
}
greeting := make([]byte, 2)
if _, err := io.ReadFull(conn, greeting); err != nil {
t.Fatalf("read socks greeting: %v", err)
}
host, portText, err := net.SplitHostPort(targetAddr)
if err != nil {
t.Fatalf("split target addr: %v", err)
}
port, err := strconv.Atoi(portText)
if err != nil {
t.Fatalf("parse target port: %v", err)
}
req := make([]byte, 0, 10)
req = append(req, 5, 1, 0, 1)
req = append(req, net.ParseIP(host).To4()...)
var portBuf [2]byte
binary.BigEndian.PutUint16(portBuf[:], uint16(port)) //nolint:gosec // SOCKS5 port is uint16 by definition
req = append(req, portBuf[:]...)
if _, err := conn.Write(req); err != nil {
t.Fatalf("write socks connect: %v", err)
}
reply := make([]byte, 10)
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read socks failure reply: %v", err)
}
return reply
}
func TestBuiltInProviderTransportMatrixValidates(t *testing.T) {
session.RegisterDefaults()
@@ -971,17 +925,6 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) {
}
}
func TestWrongClientIDIsRejected(t *testing.T) {
echoAddr := startEchoServer(t)
rt := startTunnel(t, "server-client", "wrong-client")
defer rt.stop(t)
reply := connectViaSOCKSExpectFailure(t, rt.socksAddr, echoAddr)
if !bytes.Equal(reply, []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) {
t.Fatalf("wrong client-id reply = %v, want host unreachable", reply)
}
}
func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
echoAddr := startEchoServer(t)
rt := startTunnel(t, "client-1", "client-1")

View File

@@ -0,0 +1,214 @@
// Package handshake implements the olcrtc session handshake.
//
// The handshake runs on the first smux stream (control stream) of a tunnel.
// Wire format on the control stream is length-prefixed JSON: each message is
// a 4-byte big-endian length followed by that many bytes of JSON.
//
// client server
// │ CLIENT_HELLO │
// │ ─────────────────────► │
// │ │ AuthHook(claims) → sessionID | err
// │ SERVER_WELCOME / REJECT│
// │ ◄───────────────────── │
// │ │
//
// After the exchange the control stream stays open; tunnel traffic flows over
// additional smux streams opened by the client. The control stream may carry
// keepalives or future control messages.
package handshake
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"time"
)
// ProtoVersion identifies the wire-format version. Bumped only on breaking
// changes to message layout or semantics.
const ProtoVersion = 1
// MaxMessageSize caps a single handshake frame. 64 KiB is comfortably larger
// than any legitimate HELLO/WELCOME payload and prevents memory blowups from
// malicious peers.
const MaxMessageSize = 64 * 1024
// DefaultTimeout bounds how long either side will wait for the peer's reply
// before bailing out.
const DefaultTimeout = 15 * time.Second
// MsgType labels each protocol message.
type MsgType string
const (
// TypeHello is the client's first message.
TypeHello MsgType = "CLIENT_HELLO"
// TypeWelcome is the server's success reply.
TypeWelcome MsgType = "SERVER_WELCOME"
// TypeReject is the server's failure reply.
TypeReject MsgType = "SERVER_REJECT"
)
// Hello is sent by the client to begin a session.
type Hello struct {
Version int `json:"version"`
Type MsgType `json:"type"`
DeviceID string `json:"device_id"`
Claims map[string]any `json:"claims,omitempty"`
}
// Welcome is the server's response on a successful handshake.
type Welcome struct {
Version int `json:"version"`
Type MsgType `json:"type"`
SessionID string `json:"session_id"`
}
// Reject is the server's response when auth fails.
type Reject struct {
Version int `json:"version"`
Type MsgType `json:"type"`
Reason string `json:"reason"`
}
// Errors returned by [Client] and [Server].
var (
// ErrRejected wraps a server-side rejection. The reason is in the error message.
ErrRejected = errors.New("handshake rejected")
// ErrProtocolVersion is returned when peer announces an incompatible version.
ErrProtocolVersion = errors.New("incompatible protocol version")
// ErrUnexpectedMessage is returned when a peer sends the wrong message type.
ErrUnexpectedMessage = errors.New("unexpected handshake message")
// ErrFrameTooLarge is returned when a peer announces a frame above [MaxMessageSize].
ErrFrameTooLarge = errors.New("handshake frame too large")
)
// AuthFunc is invoked by [Server] after parsing CLIENT_HELLO.
// It returns the session ID to send back to the client, or an error to reject
// the connection. The error's message is forwarded to the client as the
// reject reason, so it should not leak sensitive details.
type AuthFunc func(deviceID string, claims map[string]any) (sessionID string, err error)
// Client performs the client side of the handshake on rw and returns the
// session ID assigned by the server.
func Client(rw io.ReadWriter, deviceID string, claims map[string]any) (string, error) {
hello := Hello{
Version: ProtoVersion,
Type: TypeHello,
DeviceID: deviceID,
Claims: claims,
}
if err := writeFrame(rw, hello); err != nil {
return "", fmt.Errorf("send hello: %w", err)
}
raw, err := readFrame(rw)
if err != nil {
return "", fmt.Errorf("read welcome: %w", err)
}
var probe struct {
Type MsgType `json:"type"`
}
if err := json.Unmarshal(raw, &probe); err != nil {
return "", fmt.Errorf("parse reply: %w", err)
}
switch probe.Type {
case TypeWelcome:
var w Welcome
if err := json.Unmarshal(raw, &w); err != nil {
return "", fmt.Errorf("parse welcome: %w", err)
}
if w.Version != ProtoVersion {
return "", fmt.Errorf("%w: server v%d, client v%d",
ErrProtocolVersion, w.Version, ProtoVersion)
}
return w.SessionID, nil
case TypeReject:
var r Reject
if err := json.Unmarshal(raw, &r); err != nil {
return "", fmt.Errorf("parse reject: %w", err)
}
return "", fmt.Errorf("%w: %s", ErrRejected, r.Reason)
default:
return "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, probe.Type)
}
}
// Server performs the server side of the handshake. It reads CLIENT_HELLO,
// invokes auth, and writes the corresponding WELCOME or REJECT. On success it
// returns the parsed Hello and the session ID produced by auth.
func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) {
raw, err := readFrame(rw)
if err != nil {
return Hello{}, "", fmt.Errorf("read hello: %w", err)
}
var h Hello
if err := json.Unmarshal(raw, &h); err != nil {
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "malformed hello"})
return Hello{}, "", fmt.Errorf("parse hello: %w", err)
}
if h.Type != TypeHello {
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "expected CLIENT_HELLO"})
return h, "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, h.Type)
}
if h.Version != ProtoVersion {
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "protocol version mismatch"})
return h, "", fmt.Errorf("%w: client v%d, server v%d",
ErrProtocolVersion, h.Version, ProtoVersion)
}
sessionID, err := auth(h.DeviceID, h.Claims)
if err != nil {
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: err.Error()})
return h, "", fmt.Errorf("auth: %w", err)
}
if err := writeFrame(rw, Welcome{
Version: ProtoVersion,
Type: TypeWelcome,
SessionID: sessionID,
}); err != nil {
return h, sessionID, fmt.Errorf("send welcome: %w", err)
}
return h, sessionID, nil
}
func writeFrame(w io.Writer, msg any) error {
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
if len(body) > MaxMessageSize {
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
}
var hdr [4]byte
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
if _, err := w.Write(hdr[:]); err != nil {
return fmt.Errorf("write hdr: %w", err)
}
if _, err := w.Write(body); err != nil {
return fmt.Errorf("write body: %w", err)
}
return nil
}
func readFrame(r io.Reader) ([]byte, error) {
var hdr [4]byte
if _, err := io.ReadFull(r, hdr[:]); err != nil {
return nil, fmt.Errorf("read hdr: %w", err)
}
n := binary.BigEndian.Uint32(hdr[:])
if n > MaxMessageSize {
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
}
buf := make([]byte, n)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
return buf, nil
}

View File

@@ -0,0 +1,128 @@
package handshake
import (
"errors"
"io"
"net"
"strings"
"testing"
)
func pair(t *testing.T) (net.Conn, net.Conn) {
t.Helper()
a, b := net.Pipe()
t.Cleanup(func() {
_ = a.Close()
_ = b.Close()
})
return a, b
}
func TestHandshakeRoundTrip(t *testing.T) {
cConn, sConn := pair(t)
go func() {
hello, sid, err := Server(sConn, func(deviceID string, claims map[string]any) (string, error) {
if deviceID != "dev-1" {
t.Errorf("device id = %q", deviceID)
}
if claims["plan"] != "pro" {
t.Errorf("claims = %v", claims)
}
return "sess-42", nil
})
if err != nil {
t.Errorf("Server: %v", err)
}
if hello.DeviceID != "dev-1" || sid != "sess-42" {
t.Errorf("Server returned hello=%+v sid=%q", hello, sid)
}
}()
sid, err := Client(cConn, "dev-1", map[string]any{"plan": "pro"})
if err != nil {
t.Fatalf("Client: %v", err)
}
if sid != "sess-42" {
t.Fatalf("session id = %q, want sess-42", sid)
}
}
func TestHandshakeRejected(t *testing.T) {
cConn, sConn := pair(t)
go func() {
_, _, _ = Server(sConn, func(string, map[string]any) (string, error) {
return "", errors.New("nope")
})
}()
_, err := Client(cConn, "dev-1", nil)
if !errors.Is(err, ErrRejected) {
t.Fatalf("Client err = %v, want ErrRejected", err)
}
if !strings.Contains(err.Error(), "nope") {
t.Fatalf("err message %q missing reason", err.Error())
}
}
func TestHandshakeProtocolMismatch(t *testing.T) {
cConn, sConn := pair(t)
go func() {
_ = writeFrame(cConn, Hello{Version: 999, Type: TypeHello, DeviceID: "dev"})
_, _ = readFrame(cConn) // drain server's REJECT so its write does not block
}()
_, _, err := Server(sConn, func(string, map[string]any) (string, error) {
t.Fatal("auth must not be invoked on protocol mismatch")
return "", nil
})
if !errors.Is(err, ErrProtocolVersion) {
t.Fatalf("Server err = %v, want ErrProtocolVersion", err)
}
}
func TestHandshakeUnexpectedType(t *testing.T) {
cConn, sConn := pair(t)
go func() {
_ = writeFrame(cConn, Hello{Version: ProtoVersion, Type: "BOGUS", DeviceID: "dev"})
_, _ = readFrame(cConn) // drain server's REJECT
}()
_, _, err := Server(sConn, func(string, map[string]any) (string, error) {
t.Fatal("auth must not be invoked on bad type")
return "", nil
})
if !errors.Is(err, ErrUnexpectedMessage) {
t.Fatalf("Server err = %v, want ErrUnexpectedMessage", err)
}
}
func TestReadFrameTooLarge(t *testing.T) {
cConn, sConn := pair(t)
go func() {
var hdr [4]byte
hdr[0] = 0xff
hdr[1] = 0xff
_, _ = cConn.Write(hdr[:])
_ = cConn.Close()
}()
_, err := readFrame(sConn)
if !errors.Is(err, ErrFrameTooLarge) {
t.Fatalf("readFrame err = %v, want ErrFrameTooLarge", err)
}
}
func TestReadFrameEOF(t *testing.T) {
cConn, sConn := pair(t)
_ = cConn.Close()
_, err := readFrame(sConn)
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("readFrame err = %v", err)
}
}

View File

@@ -21,7 +21,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
ClientID: cfg.ClientID,
DeviceID: cfg.DeviceID,
Name: cfg.Name,
OnData: cfg.OnData,
DNSServer: cfg.DNSServer,

View File

@@ -62,7 +62,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
Transport: name,
Carrier: "carrier",
RoomURL: "room",
ClientID: "client",
DeviceID: "client",
Name: "peer",
DNSServer: "1.1.1.1:53",
ProxyAddr: "127.0.0.1",
@@ -84,7 +84,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
t.Fatalf("New() error = %v", err)
}
if seen.ClientID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
t.Fatalf("forwarded config = %+v", seen)
}

View File

@@ -32,7 +32,7 @@ type Config struct {
Engine string
URL string
Token string
ClientID string
DeviceID string
Name string
OnData func([]byte)
DNSServer string

View File

@@ -39,11 +39,11 @@ func TestNewAndAvailable(t *testing.T) {
called := false
Register("test-link", func(_ context.Context, cfg Config) (Link, error) {
called = cfg.ClientID == "client-1"
called = cfg.DeviceID == "client-1"
return &stubLink{}, nil
})
got, err := New(context.Background(), "test-link", Config{ClientID: "client-1"})
got, err := New(context.Background(), "test-link", Config{DeviceID: "client-1"})
if err != nil {
t.Fatalf("New() error = %v", err)
}

View File

@@ -13,7 +13,9 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
@@ -43,7 +45,9 @@ type Server struct {
sessMu sync.RWMutex
reinstallMu sync.Mutex
wg sync.WaitGroup
clientID string
authHook handshake.AuthFunc
deviceID string
sessionID string
dnsServer string
resolver *net.Resolver
socksProxyAddr string
@@ -52,10 +56,9 @@ type Server struct {
// ConnectRequest is a message from the client to establish a new connection.
type ConnectRequest struct {
Cmd string `json:"cmd"`
ClientID string `json:"clientId"`
Addr string `json:"addr"`
Port int `json:"port"`
Cmd string `json:"cmd"`
Addr string `json:"addr"`
Port int `json:"port"`
}
// Config holds runtime configuration for [Run].
@@ -65,7 +68,6 @@ type Config struct {
Carrier string
RoomURL string
KeyHex string
ClientID string
DNSServer string
SOCKSProxyAddr string
SOCKSProxyPort int
@@ -88,6 +90,10 @@ type Config struct {
Engine string
URL string
Token string
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
AuthHook handshake.AuthFunc
}
// Run starts the server with the given configuration.
@@ -100,9 +106,14 @@ func Run(ctx context.Context, cfg Config) error {
return fmt.Errorf("setupCipher failed: %w", err)
}
hook := cfg.AuthHook
if hook == nil {
hook = defaultAuthHook
}
s := &Server{
cipher: cipher,
clientID: cfg.ClientID,
authHook: hook,
dnsServer: cfg.DNSServer,
socksProxyAddr: cfg.SOCKSProxyAddr,
socksProxyPort: cfg.SOCKSProxyPort,
@@ -182,7 +193,7 @@ func (s *Server) bringUpLink(
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
ClientID: s.clientID,
DeviceID: "",
Name: names.Generate(),
OnData: s.onData,
DNSServer: s.dnsServer,
@@ -270,6 +281,8 @@ func (s *Server) reinstallSession(dead *smux.Session) {
_ = s.conn.Close()
s.conn = nil
}
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
s.installSession()
}
@@ -284,6 +297,8 @@ func (s *Server) closeSession() {
_ = s.conn.Close()
s.conn = nil
}
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
}
@@ -296,9 +311,9 @@ func (s *Server) onData(data []byte) {
}
}
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
// The loop tolerates session bounces (reconnects) by waiting until a fresh
// session is installed instead of terminating the server.
// serve drives the smux Accept loop. The first accepted stream on a given
// smux session is the control stream — the handshake runs there. Subsequent
// streams are tunnel streams and proxy traffic.
func (s *Server) serve(ctx context.Context) {
for {
select {
@@ -319,6 +334,12 @@ func (s *Server) serve(ctx context.Context) {
}
}
if !s.handshakeReady() {
if !s.acceptHandshake(ctx, sess) {
continue
}
}
stream, err := sess.AcceptStream()
if err != nil {
select {
@@ -339,6 +360,62 @@ func (s *Server) serve(ctx context.Context) {
}
}
// handshakeReady reports whether the current session has completed its
// handshake. The session is reset on reconnect, so this is recomputed.
func (s *Server) handshakeReady() bool {
s.sessMu.RLock()
defer s.sessMu.RUnlock()
return s.sessionID != ""
}
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
stream, err := sess.AcceptStream()
if err != nil {
select {
case <-ctx.Done():
return false
default:
}
logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
s.reinstallSession(sess)
return false
}
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
hello, sid, err := handshake.Server(stream, s.authHook)
_ = stream.SetDeadline(time.Time{})
if err != nil {
logger.Warnf("handshake failed: %v", err)
_ = stream.Close()
s.reinstallSession(sess)
return false
}
s.sessMu.Lock()
s.deviceID = hello.DeviceID
s.sessionID = sid
s.sessMu.Unlock()
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
// The control stream stays open for the lifetime of the session;
// keep it parked in a goroutine so the smux session does not close it.
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.parkControlStream(stream)
}()
return true
}
// parkControlStream blocks reading from the control stream until it closes.
// Future control messages (kick, rate updates, etc.) would be dispatched here.
func (s *Server) parkControlStream(stream *smux.Stream) {
defer func() { _ = stream.Close() }()
buf := make([]byte, 64)
for {
if _, err := stream.Read(buf); err != nil {
return
}
}
}
func (s *Server) shutdown() {
s.closeSession()
if s.ln != nil {
@@ -362,10 +439,6 @@ func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
header = append(header, tmp[:n]...)
if req, ok := parseConnectRequest(header); ok {
_ = stream.SetReadDeadline(time.Time{})
if !s.authorizeRequest(req) {
logger.Warnf("sid=%d rejected: client_id mismatch", stream.ID())
return
}
s.dispatch(stream, req)
return
}
@@ -390,8 +463,10 @@ func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
return req, true
}
func (s *Server) authorizeRequest(req ConnectRequest) bool {
return req.ClientID == s.clientID
// defaultAuthHook admits every client and assigns a random session ID.
// Replace it via [Config.AuthHook] to plug in real authorization.
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
return uuid.NewString(), nil
}
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {

View File

@@ -47,10 +47,9 @@ func TestSmuxConfig(t *testing.T) {
func TestParseConnectRequest(t *testing.T) {
buf, err := json.Marshal(ConnectRequest{
Cmd: "connect",
ClientID: "client-1", //nolint:goconst // test literal, repetition is intentional
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
Port: 443,
Cmd: "connect",
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
Port: 443,
})
if err != nil {
t.Fatalf("Marshal() error = %v", err)
@@ -60,7 +59,7 @@ func TestParseConnectRequest(t *testing.T) {
if !ok {
t.Fatal("parseConnectRequest() returned ok=false")
}
if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 {
if req.Addr != "example.com" || req.Port != 443 {
t.Fatalf("parseConnectRequest() = %+v", req)
}
@@ -72,13 +71,13 @@ func TestParseConnectRequest(t *testing.T) {
}
}
func TestAuthorizeRequest(t *testing.T) {
s := &Server{clientID: "client-1"}
if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) {
t.Fatal("authorizeRequest() rejected valid client")
func TestDefaultAuthHook(t *testing.T) {
sid, err := defaultAuthHook("dev", map[string]any{"x": 1})
if err != nil {
t.Fatalf("defaultAuthHook() err = %v", err)
}
if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) {
t.Fatal("authorizeRequest() accepted wrong client")
if sid == "" {
t.Fatal("defaultAuthHook() returned empty session id")
}
}
@@ -301,7 +300,7 @@ func TestSocks5ConnectTruncatesLongDomain(t *testing.T) {
}
}
func TestHandleStreamRejectsWrongClientID(t *testing.T) {
func TestHandleStreamDispatchAfterConnect(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
@@ -323,7 +322,7 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
go func() {
stream, err := serverSess.AcceptStream()
if err == nil {
(&Server{clientID: "expected"}).handleStream(context.Background(), stream)
(&Server{}).handleStream(context.Background(), stream)
}
close(done)
}()
@@ -333,10 +332,9 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
t.Fatalf("OpenStream() error = %v", err)
}
req, err := json.Marshal(ConnectRequest{
Cmd: "connect",
ClientID: "wrong",
Addr: "example.com",
Port: 443,
Cmd: "connect",
Addr: "127.0.0.1",
Port: 1, // unreachable port — dispatch will fail dial and exit
})
if err != nil {
t.Fatalf("Marshal() error = %v", err)

View File

@@ -41,7 +41,7 @@ type Config struct {
Engine string
URL string
Token string
ClientID string
DeviceID string
Name string
OnData func([]byte)
DNSServer string

View File

@@ -40,11 +40,11 @@ func TestNewAndAvailable(t *testing.T) {
called := false
Register("test-transport", func(_ context.Context, cfg Config) (Transport, error) {
called = cfg.ClientID == "client-1"
called = cfg.DeviceID == "client-1"
return &stubTransport{}, nil
})
got, err := New(context.Background(), "test-transport", Config{ClientID: "client-1"})
got, err := New(context.Background(), "test-transport", Config{DeviceID: "client-1"})
if err != nil {
t.Fatalf("New() error = %v", err)
}

View File

@@ -162,7 +162,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
writerDone: make(chan struct{}),
frameInterval: time.Second / time.Duration(fps),
batchSize: batchSize,
bindingToken: bindingToken(cfg.ClientID),
bindingToken: bindingToken(cfg.DeviceID),
localEpoch: randomEpoch(),
}

View File

@@ -92,7 +92,7 @@ func TestNewConnectSendCallbacksFeaturesAndClose(t *testing.T) {
trIface, err := New(context.Background(), transport.Config{
Carrier: name,
ClientID: "client",
DeviceID: "client",
VP8FPS: 30,
VP8BatchSize: 1,
})

View File

@@ -222,7 +222,7 @@ func Check(
Carrier: carrierName,
RoomURL: buildRoomURL(carrierName, roomID),
KeyHex: keyHex,
ClientID: clientID,
DeviceID: clientID,
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
DNSServer: defaultDNSServer,
VP8FPS: clampAtLeastOne(vp8FPS, 120),
@@ -305,7 +305,7 @@ func Ping(
Carrier: carrierName,
RoomURL: buildRoomURL(carrierName, roomID),
KeyHex: keyHex,
ClientID: clientID,
DeviceID: clientID,
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
DNSServer: defaultDNSServer,
VP8FPS: clampAtLeastOne(vp8FPS, 120),
@@ -550,7 +550,7 @@ func startWithConfig(
Carrier: carrierName,
RoomURL: roomURL,
KeyHex: keyHex,
ClientID: clientID,
DeviceID: clientID,
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
DNSServer: cfg.dnsServer,
SOCKSUser: socksUser,

View File

@@ -171,10 +171,10 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) {
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz ||
cfg.RoomURL != "any" || cfg.ClientID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 {
t.Fatalf("RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d",
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.ClientID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize)
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize)
}
onReady()
<-ctx.Done()