mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
feat: remove unused client ID from config
This commit is contained in:
@@ -83,7 +83,6 @@ func TestRunWithConfigValidationAndDataDirErrors(t *testing.T) {
|
||||
Link: "direct",
|
||||
Transport: "datachannel",
|
||||
Auth: "jazz",
|
||||
ClientID: "client",
|
||||
KeyHex: "key",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func TestRunWithArgsSuccessfulSessionReturn(t *testing.T) {
|
||||
called := false
|
||||
runSession = func(ctx context.Context, cfg session.Config) error {
|
||||
called = true
|
||||
if cfg.Mode != "srv" || cfg.Auth != "jazz" || cfg.ClientID != "client" {
|
||||
if cfg.Mode != "srv" || cfg.Auth != "jazz" {
|
||||
t.Fatalf("session config = %+v", cfg)
|
||||
}
|
||||
select {
|
||||
@@ -129,8 +128,6 @@ mode: srv
|
||||
link: direct
|
||||
auth:
|
||||
provider: jazz
|
||||
room:
|
||||
client_id: client
|
||||
crypto:
|
||||
key: key
|
||||
net:
|
||||
|
||||
@@ -101,8 +101,6 @@ var (
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
|
||||
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
|
||||
// ErrClientIDRequired indicates that client ID is required.
|
||||
ErrClientIDRequired = errors.New("client ID required (use -client-id <id>)")
|
||||
)
|
||||
|
||||
// Config holds runtime session settings.
|
||||
@@ -115,7 +113,6 @@ type Config struct {
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
ClientID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
@@ -242,9 +239,6 @@ func validateCommon(cfg Config) error {
|
||||
if cfg.RoomID == "" && cfg.Auth != authJazz && cfg.Auth != authNone {
|
||||
return ErrRoomIDRequired
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
return ErrClientIDRequired
|
||||
}
|
||||
if cfg.KeyHex == "" {
|
||||
return ErrKeyRequired
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ func TestValidate(t *testing.T) {
|
||||
Transport: "datachannel",
|
||||
Auth: "telemost",
|
||||
RoomID: "room-1",
|
||||
ClientID: "client-1",
|
||||
KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
|
||||
DNSServer: "1.1.1.1:53", //nolint:goconst // test literal, repetition is intentional
|
||||
}
|
||||
@@ -91,15 +90,6 @@ func TestValidate(t *testing.T) {
|
||||
}(),
|
||||
want: ErrRoomIDRequired,
|
||||
},
|
||||
{
|
||||
name: "client id required",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.ClientID = ""
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrClientIDRequired,
|
||||
},
|
||||
{
|
||||
name: "key required",
|
||||
cfg: func() Config {
|
||||
|
||||
@@ -10,10 +10,15 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
@@ -44,15 +49,18 @@ var (
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them to the server.
|
||||
type Client struct {
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
sessMu sync.RWMutex
|
||||
clientID string
|
||||
dnsServer string
|
||||
socksUser string
|
||||
socksPass string
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStrm *smux.Stream
|
||||
sessMu sync.RWMutex
|
||||
deviceID string
|
||||
sessionID string
|
||||
claims map[string]any
|
||||
dnsServer string
|
||||
socksUser string
|
||||
socksPass string
|
||||
}
|
||||
|
||||
// Config holds runtime configuration for [Run] and [RunWithReady].
|
||||
@@ -62,7 +70,6 @@ type Config struct {
|
||||
Carrier string
|
||||
RoomURL string
|
||||
KeyHex string
|
||||
ClientID string
|
||||
LocalAddr string
|
||||
DNSServer string
|
||||
SOCKSUser string
|
||||
@@ -86,6 +93,19 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
|
||||
// DeviceID overrides the persistent client-side device identifier. Leave
|
||||
// empty to derive one from DeviceIDPath (or generate a random one if both
|
||||
// are empty).
|
||||
DeviceID string
|
||||
|
||||
// DeviceIDPath is a file in which to persist the auto-generated device ID
|
||||
// across restarts. Ignored when DeviceID is set explicitly.
|
||||
DeviceIDPath string
|
||||
|
||||
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
|
||||
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
|
||||
Claims map[string]any
|
||||
}
|
||||
|
||||
// Run starts the client with the given configuration.
|
||||
@@ -103,9 +123,15 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve device id: %w", err)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
cipher: cipher,
|
||||
clientID: cfg.ClientID,
|
||||
deviceID: deviceID,
|
||||
claims: cfg.Claims,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksUser: cfg.SOCKSUser,
|
||||
socksPass: cfg.SOCKSPass,
|
||||
@@ -147,7 +173,7 @@ func (c *Client) bringUpLink(
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
ClientID: c.clientID,
|
||||
DeviceID: c.deviceID,
|
||||
Name: names.Generate(),
|
||||
OnData: c.onData,
|
||||
DNSServer: cfg.DNSServer,
|
||||
@@ -188,14 +214,80 @@ func (c *Client) bringUpLink(
|
||||
if err != nil {
|
||||
return fmt.Errorf("smux client: %w", err)
|
||||
}
|
||||
|
||||
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
|
||||
if err != nil {
|
||||
_ = sess.Close()
|
||||
_ = c.conn.Close()
|
||||
return fmt.Errorf("handshake: %w", err)
|
||||
}
|
||||
logger.Infof("session %s opened (device=%s)", sid, c.deviceID)
|
||||
|
||||
c.sessMu.Lock()
|
||||
c.session = sess
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
|
||||
go ln.WatchConnection(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openControlStream opens stream #1 on sess and performs the handshake.
|
||||
// The stream stays open for the lifetime of the smux session — the server
|
||||
// holds it parked, and it would carry future control messages.
|
||||
func openControlStream(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
claims map[string]any,
|
||||
) (*smux.Stream, string, error) {
|
||||
stream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("open control stream: %w", err)
|
||||
}
|
||||
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
|
||||
sid, err := handshake.Client(stream, deviceID, claims)
|
||||
_ = stream.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
return stream, sid, nil
|
||||
}
|
||||
|
||||
// resolveDeviceID returns the device ID to send in CLIENT_HELLO.
|
||||
//
|
||||
// Precedence:
|
||||
// 1. Explicit deviceID arg (Config.DeviceID) — used verbatim.
|
||||
// 2. Persistent file at path (Config.DeviceIDPath) — read if it exists,
|
||||
// otherwise generated and written for future runs.
|
||||
// 3. Random UUID per run when both inputs are empty.
|
||||
func resolveDeviceID(deviceID, path string) (string, error) {
|
||||
if deviceID != "" {
|
||||
return deviceID, nil
|
||||
}
|
||||
if path == "" {
|
||||
return uuid.NewString(), nil
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err == nil {
|
||||
id := strings.TrimSpace(string(data))
|
||||
if id != "" {
|
||||
return id, nil
|
||||
}
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return "", fmt.Errorf("read device id %s: %w", path, err)
|
||||
}
|
||||
id := uuid.NewString()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return "", fmt.Errorf("mkdir device id dir: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil {
|
||||
return "", fmt.Errorf("write device id %s: %w", path, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// smuxConfig returns the tuned smux config used on both ends.
|
||||
func smuxConfig() *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
@@ -212,6 +304,10 @@ func smuxConfig() *smux.Config {
|
||||
func (c *Client) handleReconnect() {
|
||||
logger.Infof("client link reconnect - tearing down smux session")
|
||||
c.sessMu.Lock()
|
||||
if c.controlStrm != nil {
|
||||
_ = c.controlStrm.Close()
|
||||
c.controlStrm = nil
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
c.session = nil
|
||||
@@ -220,6 +316,7 @@ func (c *Client) handleReconnect() {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.sessionID = ""
|
||||
c.sessMu.Unlock()
|
||||
c.conn = muxconn.New(c.ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
@@ -227,13 +324,25 @@ func (c *Client) handleReconnect() {
|
||||
logger.Warnf("smux re-init failed: %v", err)
|
||||
return
|
||||
}
|
||||
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
|
||||
if err != nil {
|
||||
logger.Warnf("handshake on reconnect failed: %v", err)
|
||||
_ = sess.Close()
|
||||
return
|
||||
}
|
||||
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
|
||||
c.sessMu.Lock()
|
||||
c.session = sess
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
if c.controlStrm != nil {
|
||||
_ = c.controlStrm.Close()
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
@@ -340,10 +449,9 @@ func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, ta
|
||||
|
||||
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
|
||||
connectReq, err := json.Marshal(map[string]any{
|
||||
"cmd": "connect",
|
||||
"clientId": c.clientID,
|
||||
"addr": targetAddr,
|
||||
"port": targetPort,
|
||||
"cmd": "connect",
|
||||
"addr": targetAddr,
|
||||
"port": targetPort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
|
||||
|
||||
@@ -417,7 +417,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
if req["cmd"] != "connect" || req["clientId"] != "client-1" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional
|
||||
if req["cmd"] != "connect" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional
|
||||
done <- errUnexpectedConnectRequest
|
||||
return
|
||||
}
|
||||
@@ -431,7 +431,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
c := &Client{clientID: "client-1"}
|
||||
c := &Client{deviceID: "client-1"}
|
||||
if err := c.sendConnectRequest(stream, "example.com", 443); err != nil {
|
||||
t.Fatalf("sendConnectRequest() error = %v", err)
|
||||
}
|
||||
@@ -473,7 +473,7 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
c := &Client{clientID: "client-1"}
|
||||
c := &Client{deviceID: "client-1"}
|
||||
if err := c.sendConnectRequest(stream, "example.com", 443); !errors.Is(err, ErrRemoteNotReady) {
|
||||
t.Fatalf("sendConnectRequest() error = %v, want %v", err, ErrRemoteNotReady)
|
||||
}
|
||||
|
||||
@@ -45,8 +45,7 @@ type Auth struct {
|
||||
|
||||
// Room identifies the conference room.
|
||||
type Room struct {
|
||||
ID string `yaml:"id"`
|
||||
ClientID string `yaml:"client_id"` // deprecated: server identifier (will be removed)
|
||||
ID string `yaml:"id"`
|
||||
}
|
||||
|
||||
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
|
||||
@@ -137,7 +136,6 @@ func Apply(dst session.Config, f File) session.Config {
|
||||
dst.URL = pickString(dst.URL, f.Engine.URL)
|
||||
dst.Token = pickString(dst.Token, f.Engine.Token)
|
||||
dst.RoomID = pickString(dst.RoomID, f.Room.ID)
|
||||
dst.ClientID = pickString(dst.ClientID, f.Room.ClientID)
|
||||
dst.KeyHex = pickString(dst.KeyHex, f.Crypto.Key)
|
||||
dst.SOCKSHost = pickString(dst.SOCKSHost, f.SOCKS.Host)
|
||||
dst.SOCKSPort = pickInt(dst.SOCKSPort, f.SOCKS.Port)
|
||||
|
||||
@@ -18,7 +18,6 @@ auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: r1
|
||||
client_id: c1
|
||||
crypto:
|
||||
key: deadbeef
|
||||
net:
|
||||
@@ -50,7 +49,7 @@ debug: true
|
||||
|
||||
got := Apply(session.Config{}, f)
|
||||
if got.Mode != "srv" || got.Link != "direct" || got.Auth != "wbstream" ||
|
||||
got.RoomID != "r1" || got.ClientID != "c1" || got.KeyHex != "deadbeef" ||
|
||||
got.RoomID != "r1" || got.KeyHex != "deadbeef" ||
|
||||
got.Transport != "datachannel" || got.DNSServer != "1.1.1.1:53" ||
|
||||
got.SOCKSHost != "127.0.0.1" || got.SOCKSPort != 1080 ||
|
||||
got.SOCKSUser != "u" || got.SOCKSPass != "p" ||
|
||||
|
||||
@@ -400,7 +400,6 @@ func validSessionConfig(mode, carrierName, transportName string) session.Config
|
||||
Transport: transportName,
|
||||
Auth: carrierName,
|
||||
RoomID: "room",
|
||||
ClientID: "client-1",
|
||||
KeyHex: testKeyHex,
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
@@ -428,7 +427,7 @@ func validLinkConfig(carrierName, transportName string) link.Config {
|
||||
Transport: cfg.Transport,
|
||||
Carrier: cfg.Auth,
|
||||
RoomURL: "room",
|
||||
ClientID: cfg.ClientID,
|
||||
DeviceID: "e2e-link-test",
|
||||
Name: "e2e-" + carrierName + "-" + transportName,
|
||||
DNSServer: cfg.DNSServer,
|
||||
VideoWidth: cfg.VideoWidth,
|
||||
@@ -505,7 +504,7 @@ type tunnelRuntime struct {
|
||||
clientErr chan error
|
||||
}
|
||||
|
||||
func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRuntime {
|
||||
func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime {
|
||||
t.Helper()
|
||||
|
||||
carrierName, room := registerMemoryCarrier(t)
|
||||
@@ -521,7 +520,6 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
|
||||
Carrier: carrierName,
|
||||
RoomURL: "room",
|
||||
KeyHex: testKeyHex,
|
||||
ClientID: serverClientID,
|
||||
DNSServer: "127.0.0.1:53",
|
||||
})
|
||||
}()
|
||||
@@ -536,7 +534,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
|
||||
Carrier: carrierName,
|
||||
RoomURL: "room",
|
||||
KeyHex: testKeyHex,
|
||||
ClientID: clientClientID,
|
||||
DeviceID: deviceID,
|
||||
LocalAddr: socksAddr,
|
||||
DNSServer: "127.0.0.1:53",
|
||||
}, func() { close(ready) })
|
||||
@@ -555,7 +553,7 @@ func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRun
|
||||
func startRealTunnel(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
carrierName, transportName, roomURL, serverClientID, clientClientID string,
|
||||
carrierName, transportName, roomURL, _, clientDeviceID string,
|
||||
) (*tunnelRuntime, error) {
|
||||
t.Helper()
|
||||
|
||||
@@ -573,7 +571,6 @@ func startRealTunnel(
|
||||
Carrier: carrierName,
|
||||
RoomURL: roomURL,
|
||||
KeyHex: testKeyHex,
|
||||
ClientID: serverClientID,
|
||||
DNSServer: "127.0.0.1:53",
|
||||
VideoWidth: 1080,
|
||||
VideoHeight: 1080,
|
||||
@@ -613,7 +610,7 @@ func startRealTunnel(
|
||||
Carrier: carrierName,
|
||||
RoomURL: roomURL,
|
||||
KeyHex: testKeyHex,
|
||||
ClientID: clientClientID,
|
||||
DeviceID: clientDeviceID,
|
||||
LocalAddr: socksAddr,
|
||||
DNSServer: "127.0.0.1:53",
|
||||
VideoWidth: 1080,
|
||||
@@ -749,49 +746,6 @@ func connectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) []byte {
|
||||
t.Helper()
|
||||
|
||||
dialer := net.Dialer{Timeout: 2 * time.Second}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp4", socksAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("dial socks: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
|
||||
t.Fatalf("write socks greeting: %v", err)
|
||||
}
|
||||
greeting := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, greeting); err != nil {
|
||||
t.Fatalf("read socks greeting: %v", err)
|
||||
}
|
||||
|
||||
host, portText, err := net.SplitHostPort(targetAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("split target addr: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target port: %v", err)
|
||||
}
|
||||
req := make([]byte, 0, 10)
|
||||
req = append(req, 5, 1, 0, 1)
|
||||
req = append(req, net.ParseIP(host).To4()...)
|
||||
var portBuf [2]byte
|
||||
binary.BigEndian.PutUint16(portBuf[:], uint16(port)) //nolint:gosec // SOCKS5 port is uint16 by definition
|
||||
req = append(req, portBuf[:]...)
|
||||
if _, err := conn.Write(req); err != nil {
|
||||
t.Fatalf("write socks connect: %v", err)
|
||||
}
|
||||
|
||||
reply := make([]byte, 10)
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("read socks failure reply: %v", err)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
func TestBuiltInProviderTransportMatrixValidates(t *testing.T) {
|
||||
session.RegisterDefaults()
|
||||
|
||||
@@ -971,17 +925,6 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongClientIDIsRejected(t *testing.T) {
|
||||
echoAddr := startEchoServer(t)
|
||||
rt := startTunnel(t, "server-client", "wrong-client")
|
||||
defer rt.stop(t)
|
||||
|
||||
reply := connectViaSOCKSExpectFailure(t, rt.socksAddr, echoAddr)
|
||||
if !bytes.Equal(reply, []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) {
|
||||
t.Fatalf("wrong client-id reply = %v, want host unreachable", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
|
||||
echoAddr := startEchoServer(t)
|
||||
rt := startTunnel(t, "client-1", "client-1")
|
||||
|
||||
214
internal/handshake/handshake.go
Normal file
214
internal/handshake/handshake.go
Normal file
@@ -0,0 +1,214 @@
|
||||
// Package handshake implements the olcrtc session handshake.
|
||||
//
|
||||
// The handshake runs on the first smux stream (control stream) of a tunnel.
|
||||
// Wire format on the control stream is length-prefixed JSON: each message is
|
||||
// a 4-byte big-endian length followed by that many bytes of JSON.
|
||||
//
|
||||
// client server
|
||||
// │ CLIENT_HELLO │
|
||||
// │ ─────────────────────► │
|
||||
// │ │ AuthHook(claims) → sessionID | err
|
||||
// │ SERVER_WELCOME / REJECT│
|
||||
// │ ◄───────────────────── │
|
||||
// │ │
|
||||
//
|
||||
// After the exchange the control stream stays open; tunnel traffic flows over
|
||||
// additional smux streams opened by the client. The control stream may carry
|
||||
// keepalives or future control messages.
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProtoVersion identifies the wire-format version. Bumped only on breaking
|
||||
// changes to message layout or semantics.
|
||||
const ProtoVersion = 1
|
||||
|
||||
// MaxMessageSize caps a single handshake frame. 64 KiB is comfortably larger
|
||||
// than any legitimate HELLO/WELCOME payload and prevents memory blowups from
|
||||
// malicious peers.
|
||||
const MaxMessageSize = 64 * 1024
|
||||
|
||||
// DefaultTimeout bounds how long either side will wait for the peer's reply
|
||||
// before bailing out.
|
||||
const DefaultTimeout = 15 * time.Second
|
||||
|
||||
// MsgType labels each protocol message.
|
||||
type MsgType string
|
||||
|
||||
const (
|
||||
// TypeHello is the client's first message.
|
||||
TypeHello MsgType = "CLIENT_HELLO"
|
||||
// TypeWelcome is the server's success reply.
|
||||
TypeWelcome MsgType = "SERVER_WELCOME"
|
||||
// TypeReject is the server's failure reply.
|
||||
TypeReject MsgType = "SERVER_REJECT"
|
||||
)
|
||||
|
||||
// Hello is sent by the client to begin a session.
|
||||
type Hello struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
DeviceID string `json:"device_id"`
|
||||
Claims map[string]any `json:"claims,omitempty"`
|
||||
}
|
||||
|
||||
// Welcome is the server's response on a successful handshake.
|
||||
type Welcome struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// Reject is the server's response when auth fails.
|
||||
type Reject struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// Errors returned by [Client] and [Server].
|
||||
var (
|
||||
// ErrRejected wraps a server-side rejection. The reason is in the error message.
|
||||
ErrRejected = errors.New("handshake rejected")
|
||||
// ErrProtocolVersion is returned when peer announces an incompatible version.
|
||||
ErrProtocolVersion = errors.New("incompatible protocol version")
|
||||
// ErrUnexpectedMessage is returned when a peer sends the wrong message type.
|
||||
ErrUnexpectedMessage = errors.New("unexpected handshake message")
|
||||
// ErrFrameTooLarge is returned when a peer announces a frame above [MaxMessageSize].
|
||||
ErrFrameTooLarge = errors.New("handshake frame too large")
|
||||
)
|
||||
|
||||
// AuthFunc is invoked by [Server] after parsing CLIENT_HELLO.
|
||||
// It returns the session ID to send back to the client, or an error to reject
|
||||
// the connection. The error's message is forwarded to the client as the
|
||||
// reject reason, so it should not leak sensitive details.
|
||||
type AuthFunc func(deviceID string, claims map[string]any) (sessionID string, err error)
|
||||
|
||||
// Client performs the client side of the handshake on rw and returns the
|
||||
// session ID assigned by the server.
|
||||
func Client(rw io.ReadWriter, deviceID string, claims map[string]any) (string, error) {
|
||||
hello := Hello{
|
||||
Version: ProtoVersion,
|
||||
Type: TypeHello,
|
||||
DeviceID: deviceID,
|
||||
Claims: claims,
|
||||
}
|
||||
if err := writeFrame(rw, hello); err != nil {
|
||||
return "", fmt.Errorf("send hello: %w", err)
|
||||
}
|
||||
|
||||
raw, err := readFrame(rw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read welcome: %w", err)
|
||||
}
|
||||
|
||||
var probe struct {
|
||||
Type MsgType `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &probe); err != nil {
|
||||
return "", fmt.Errorf("parse reply: %w", err)
|
||||
}
|
||||
|
||||
switch probe.Type {
|
||||
case TypeWelcome:
|
||||
var w Welcome
|
||||
if err := json.Unmarshal(raw, &w); err != nil {
|
||||
return "", fmt.Errorf("parse welcome: %w", err)
|
||||
}
|
||||
if w.Version != ProtoVersion {
|
||||
return "", fmt.Errorf("%w: server v%d, client v%d",
|
||||
ErrProtocolVersion, w.Version, ProtoVersion)
|
||||
}
|
||||
return w.SessionID, nil
|
||||
case TypeReject:
|
||||
var r Reject
|
||||
if err := json.Unmarshal(raw, &r); err != nil {
|
||||
return "", fmt.Errorf("parse reject: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("%w: %s", ErrRejected, r.Reason)
|
||||
default:
|
||||
return "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, probe.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Server performs the server side of the handshake. It reads CLIENT_HELLO,
|
||||
// invokes auth, and writes the corresponding WELCOME or REJECT. On success it
|
||||
// returns the parsed Hello and the session ID produced by auth.
|
||||
func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) {
|
||||
raw, err := readFrame(rw)
|
||||
if err != nil {
|
||||
return Hello{}, "", fmt.Errorf("read hello: %w", err)
|
||||
}
|
||||
|
||||
var h Hello
|
||||
if err := json.Unmarshal(raw, &h); err != nil {
|
||||
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "malformed hello"})
|
||||
return Hello{}, "", fmt.Errorf("parse hello: %w", err)
|
||||
}
|
||||
if h.Type != TypeHello {
|
||||
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "expected CLIENT_HELLO"})
|
||||
return h, "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, h.Type)
|
||||
}
|
||||
if h.Version != ProtoVersion {
|
||||
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: "protocol version mismatch"})
|
||||
return h, "", fmt.Errorf("%w: client v%d, server v%d",
|
||||
ErrProtocolVersion, h.Version, ProtoVersion)
|
||||
}
|
||||
|
||||
sessionID, err := auth(h.DeviceID, h.Claims)
|
||||
if err != nil {
|
||||
_ = writeFrame(rw, Reject{Version: ProtoVersion, Type: TypeReject, Reason: err.Error()})
|
||||
return h, "", fmt.Errorf("auth: %w", err)
|
||||
}
|
||||
|
||||
if err := writeFrame(rw, Welcome{
|
||||
Version: ProtoVersion,
|
||||
Type: TypeWelcome,
|
||||
SessionID: sessionID,
|
||||
}); err != nil {
|
||||
return h, sessionID, fmt.Errorf("send welcome: %w", err)
|
||||
}
|
||||
return h, sessionID, nil
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg any) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
128
internal/handshake/handshake_test.go
Normal file
128
internal/handshake/handshake_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func pair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
a, b := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
})
|
||||
return a, b
|
||||
}
|
||||
|
||||
func TestHandshakeRoundTrip(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
|
||||
go func() {
|
||||
hello, sid, err := Server(sConn, func(deviceID string, claims map[string]any) (string, error) {
|
||||
if deviceID != "dev-1" {
|
||||
t.Errorf("device id = %q", deviceID)
|
||||
}
|
||||
if claims["plan"] != "pro" {
|
||||
t.Errorf("claims = %v", claims)
|
||||
}
|
||||
return "sess-42", nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Server: %v", err)
|
||||
}
|
||||
if hello.DeviceID != "dev-1" || sid != "sess-42" {
|
||||
t.Errorf("Server returned hello=%+v sid=%q", hello, sid)
|
||||
}
|
||||
}()
|
||||
|
||||
sid, err := Client(cConn, "dev-1", map[string]any{"plan": "pro"})
|
||||
if err != nil {
|
||||
t.Fatalf("Client: %v", err)
|
||||
}
|
||||
if sid != "sess-42" {
|
||||
t.Fatalf("session id = %q, want sess-42", sid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeRejected(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
|
||||
go func() {
|
||||
_, _, _ = Server(sConn, func(string, map[string]any) (string, error) {
|
||||
return "", errors.New("nope")
|
||||
})
|
||||
}()
|
||||
|
||||
_, err := Client(cConn, "dev-1", nil)
|
||||
if !errors.Is(err, ErrRejected) {
|
||||
t.Fatalf("Client err = %v, want ErrRejected", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "nope") {
|
||||
t.Fatalf("err message %q missing reason", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeProtocolMismatch(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
|
||||
go func() {
|
||||
_ = writeFrame(cConn, Hello{Version: 999, Type: TypeHello, DeviceID: "dev"})
|
||||
_, _ = readFrame(cConn) // drain server's REJECT so its write does not block
|
||||
}()
|
||||
|
||||
_, _, err := Server(sConn, func(string, map[string]any) (string, error) {
|
||||
t.Fatal("auth must not be invoked on protocol mismatch")
|
||||
return "", nil
|
||||
})
|
||||
if !errors.Is(err, ErrProtocolVersion) {
|
||||
t.Fatalf("Server err = %v, want ErrProtocolVersion", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeUnexpectedType(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
|
||||
go func() {
|
||||
_ = writeFrame(cConn, Hello{Version: ProtoVersion, Type: "BOGUS", DeviceID: "dev"})
|
||||
_, _ = readFrame(cConn) // drain server's REJECT
|
||||
}()
|
||||
|
||||
_, _, err := Server(sConn, func(string, map[string]any) (string, error) {
|
||||
t.Fatal("auth must not be invoked on bad type")
|
||||
return "", nil
|
||||
})
|
||||
if !errors.Is(err, ErrUnexpectedMessage) {
|
||||
t.Fatalf("Server err = %v, want ErrUnexpectedMessage", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameTooLarge(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
|
||||
go func() {
|
||||
var hdr [4]byte
|
||||
hdr[0] = 0xff
|
||||
hdr[1] = 0xff
|
||||
_, _ = cConn.Write(hdr[:])
|
||||
_ = cConn.Close()
|
||||
}()
|
||||
|
||||
_, err := readFrame(sConn)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("readFrame err = %v, want ErrFrameTooLarge", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameEOF(t *testing.T) {
|
||||
cConn, sConn := pair(t)
|
||||
_ = cConn.Close()
|
||||
|
||||
_, err := readFrame(sConn)
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Fatalf("readFrame err = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
ClientID: cfg.ClientID,
|
||||
DeviceID: cfg.DeviceID,
|
||||
Name: cfg.Name,
|
||||
OnData: cfg.OnData,
|
||||
DNSServer: cfg.DNSServer,
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
Transport: name,
|
||||
Carrier: "carrier",
|
||||
RoomURL: "room",
|
||||
ClientID: "client",
|
||||
DeviceID: "client",
|
||||
Name: "peer",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
ProxyAddr: "127.0.0.1",
|
||||
@@ -84,7 +84,7 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if seen.ClientID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
|
||||
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
|
||||
t.Fatalf("forwarded config = %+v", seen)
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
ClientID string
|
||||
DeviceID string
|
||||
Name string
|
||||
OnData func([]byte)
|
||||
DNSServer string
|
||||
|
||||
@@ -39,11 +39,11 @@ func TestNewAndAvailable(t *testing.T) {
|
||||
|
||||
called := false
|
||||
Register("test-link", func(_ context.Context, cfg Config) (Link, error) {
|
||||
called = cfg.ClientID == "client-1"
|
||||
called = cfg.DeviceID == "client-1"
|
||||
return &stubLink{}, nil
|
||||
})
|
||||
|
||||
got, err := New(context.Background(), "test-link", Config{ClientID: "client-1"})
|
||||
got, err := New(context.Background(), "test-link", Config{DeviceID: "client-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
@@ -43,7 +45,9 @@ type Server struct {
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
clientID string
|
||||
authHook handshake.AuthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
@@ -52,10 +56,9 @@ type Server struct {
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
type ConnectRequest struct {
|
||||
Cmd string `json:"cmd"`
|
||||
ClientID string `json:"clientId"`
|
||||
Addr string `json:"addr"`
|
||||
Port int `json:"port"`
|
||||
Cmd string `json:"cmd"`
|
||||
Addr string `json:"addr"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// Config holds runtime configuration for [Run].
|
||||
@@ -65,7 +68,6 @@ type Config struct {
|
||||
Carrier string
|
||||
RoomURL string
|
||||
KeyHex string
|
||||
ClientID string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
@@ -88,6 +90,10 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
AuthHook handshake.AuthFunc
|
||||
}
|
||||
|
||||
// Run starts the server with the given configuration.
|
||||
@@ -100,9 +106,14 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
hook := cfg.AuthHook
|
||||
if hook == nil {
|
||||
hook = defaultAuthHook
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
clientID: cfg.ClientID,
|
||||
authHook: hook,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
@@ -182,7 +193,7 @@ func (s *Server) bringUpLink(
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
ClientID: s.clientID,
|
||||
DeviceID: "",
|
||||
Name: names.Generate(),
|
||||
OnData: s.onData,
|
||||
DNSServer: s.dnsServer,
|
||||
@@ -270,6 +281,8 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
s.installSession()
|
||||
}
|
||||
@@ -284,6 +297,8 @@ func (s *Server) closeSession() {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -296,9 +311,9 @@ func (s *Server) onData(data []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
|
||||
// The loop tolerates session bounces (reconnects) by waiting until a fresh
|
||||
// session is installed instead of terminating the server.
|
||||
// serve drives the smux Accept loop. The first accepted stream on a given
|
||||
// smux session is the control stream — the handshake runs there. Subsequent
|
||||
// streams are tunnel streams and proxy traffic.
|
||||
func (s *Server) serve(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
@@ -319,6 +334,12 @@ func (s *Server) serve(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if !s.handshakeReady() {
|
||||
if !s.acceptHandshake(ctx, sess) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
select {
|
||||
@@ -339,6 +360,62 @@ func (s *Server) serve(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handshakeReady reports whether the current session has completed its
|
||||
// handshake. The session is reset on reconnect, so this is recomputed.
|
||||
func (s *Server) handshakeReady() bool {
|
||||
s.sessMu.RLock()
|
||||
defer s.sessMu.RUnlock()
|
||||
return s.sessionID != ""
|
||||
}
|
||||
|
||||
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
|
||||
s.reinstallSession(sess)
|
||||
return false
|
||||
}
|
||||
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
|
||||
hello, sid, err := handshake.Server(stream, s.authHook)
|
||||
_ = stream.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logger.Warnf("handshake failed: %v", err)
|
||||
_ = stream.Close()
|
||||
s.reinstallSession(sess)
|
||||
return false
|
||||
}
|
||||
s.sessMu.Lock()
|
||||
s.deviceID = hello.DeviceID
|
||||
s.sessionID = sid
|
||||
s.sessMu.Unlock()
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.parkControlStream(stream)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// parkControlStream blocks reading from the control stream until it closes.
|
||||
// Future control messages (kick, rate updates, etc.) would be dispatched here.
|
||||
func (s *Server) parkControlStream(stream *smux.Stream) {
|
||||
defer func() { _ = stream.Close() }()
|
||||
buf := make([]byte, 64)
|
||||
for {
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
s.closeSession()
|
||||
if s.ln != nil {
|
||||
@@ -362,10 +439,6 @@ func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
|
||||
header = append(header, tmp[:n]...)
|
||||
if req, ok := parseConnectRequest(header); ok {
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
if !s.authorizeRequest(req) {
|
||||
logger.Warnf("sid=%d rejected: client_id mismatch", stream.ID())
|
||||
return
|
||||
}
|
||||
s.dispatch(stream, req)
|
||||
return
|
||||
}
|
||||
@@ -390,8 +463,10 @@ func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
|
||||
return req, true
|
||||
}
|
||||
|
||||
func (s *Server) authorizeRequest(req ConnectRequest) bool {
|
||||
return req.ClientID == s.clientID
|
||||
// defaultAuthHook admits every client and assigns a random session ID.
|
||||
// Replace it via [Config.AuthHook] to plug in real authorization.
|
||||
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
|
||||
return uuid.NewString(), nil
|
||||
}
|
||||
|
||||
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
|
||||
@@ -47,10 +47,9 @@ func TestSmuxConfig(t *testing.T) {
|
||||
|
||||
func TestParseConnectRequest(t *testing.T) {
|
||||
buf, err := json.Marshal(ConnectRequest{
|
||||
Cmd: "connect",
|
||||
ClientID: "client-1", //nolint:goconst // test literal, repetition is intentional
|
||||
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
|
||||
Port: 443,
|
||||
Cmd: "connect",
|
||||
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
|
||||
Port: 443,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
@@ -60,7 +59,7 @@ func TestParseConnectRequest(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("parseConnectRequest() returned ok=false")
|
||||
}
|
||||
if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 {
|
||||
if req.Addr != "example.com" || req.Port != 443 {
|
||||
t.Fatalf("parseConnectRequest() = %+v", req)
|
||||
}
|
||||
|
||||
@@ -72,13 +71,13 @@ func TestParseConnectRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeRequest(t *testing.T) {
|
||||
s := &Server{clientID: "client-1"}
|
||||
if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) {
|
||||
t.Fatal("authorizeRequest() rejected valid client")
|
||||
func TestDefaultAuthHook(t *testing.T) {
|
||||
sid, err := defaultAuthHook("dev", map[string]any{"x": 1})
|
||||
if err != nil {
|
||||
t.Fatalf("defaultAuthHook() err = %v", err)
|
||||
}
|
||||
if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) {
|
||||
t.Fatal("authorizeRequest() accepted wrong client")
|
||||
if sid == "" {
|
||||
t.Fatal("defaultAuthHook() returned empty session id")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +300,7 @@ func TestSocks5ConnectTruncatesLongDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
func TestHandleStreamDispatchAfterConnect(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
@@ -323,7 +322,7 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
(&Server{clientID: "expected"}).handleStream(context.Background(), stream)
|
||||
(&Server{}).handleStream(context.Background(), stream)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
@@ -333,10 +332,9 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
req, err := json.Marshal(ConnectRequest{
|
||||
Cmd: "connect",
|
||||
ClientID: "wrong",
|
||||
Addr: "example.com",
|
||||
Port: 443,
|
||||
Cmd: "connect",
|
||||
Addr: "127.0.0.1",
|
||||
Port: 1, // unreachable port — dispatch will fail dial and exit
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
|
||||
@@ -41,7 +41,7 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
ClientID string
|
||||
DeviceID string
|
||||
Name string
|
||||
OnData func([]byte)
|
||||
DNSServer string
|
||||
|
||||
@@ -40,11 +40,11 @@ func TestNewAndAvailable(t *testing.T) {
|
||||
|
||||
called := false
|
||||
Register("test-transport", func(_ context.Context, cfg Config) (Transport, error) {
|
||||
called = cfg.ClientID == "client-1"
|
||||
called = cfg.DeviceID == "client-1"
|
||||
return &stubTransport{}, nil
|
||||
})
|
||||
|
||||
got, err := New(context.Background(), "test-transport", Config{ClientID: "client-1"})
|
||||
got, err := New(context.Background(), "test-transport", Config{DeviceID: "client-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
|
||||
writerDone: make(chan struct{}),
|
||||
frameInterval: time.Second / time.Duration(fps),
|
||||
batchSize: batchSize,
|
||||
bindingToken: bindingToken(cfg.ClientID),
|
||||
bindingToken: bindingToken(cfg.DeviceID),
|
||||
localEpoch: randomEpoch(),
|
||||
}
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ func TestNewConnectSendCallbacksFeaturesAndClose(t *testing.T) {
|
||||
|
||||
trIface, err := New(context.Background(), transport.Config{
|
||||
Carrier: name,
|
||||
ClientID: "client",
|
||||
DeviceID: "client",
|
||||
VP8FPS: 30,
|
||||
VP8BatchSize: 1,
|
||||
})
|
||||
|
||||
@@ -222,7 +222,7 @@ func Check(
|
||||
Carrier: carrierName,
|
||||
RoomURL: buildRoomURL(carrierName, roomID),
|
||||
KeyHex: keyHex,
|
||||
ClientID: clientID,
|
||||
DeviceID: clientID,
|
||||
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
|
||||
DNSServer: defaultDNSServer,
|
||||
VP8FPS: clampAtLeastOne(vp8FPS, 120),
|
||||
@@ -305,7 +305,7 @@ func Ping(
|
||||
Carrier: carrierName,
|
||||
RoomURL: buildRoomURL(carrierName, roomID),
|
||||
KeyHex: keyHex,
|
||||
ClientID: clientID,
|
||||
DeviceID: clientID,
|
||||
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
|
||||
DNSServer: defaultDNSServer,
|
||||
VP8FPS: clampAtLeastOne(vp8FPS, 120),
|
||||
@@ -550,7 +550,7 @@ func startWithConfig(
|
||||
Carrier: carrierName,
|
||||
RoomURL: roomURL,
|
||||
KeyHex: keyHex,
|
||||
ClientID: clientID,
|
||||
DeviceID: clientID,
|
||||
LocalAddr: fmt.Sprintf("127.0.0.1:%d", socksPort),
|
||||
DNSServer: cfg.dnsServer,
|
||||
SOCKSUser: socksUser,
|
||||
|
||||
@@ -171,10 +171,10 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) {
|
||||
|
||||
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
|
||||
if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz ||
|
||||
cfg.RoomURL != "any" || cfg.ClientID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
|
||||
cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
|
||||
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 {
|
||||
t.Fatalf("RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d",
|
||||
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.ClientID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize)
|
||||
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize)
|
||||
}
|
||||
onReady()
|
||||
<-ctx.Done()
|
||||
|
||||
Reference in New Issue
Block a user