mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
feat: remove unused client ID from config
This commit is contained in:
@@ -13,7 +13,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
@@ -43,7 +45,9 @@ type Server struct {
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
clientID string
|
||||
authHook handshake.AuthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
@@ -52,10 +56,9 @@ type Server struct {
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
type ConnectRequest struct {
|
||||
Cmd string `json:"cmd"`
|
||||
ClientID string `json:"clientId"`
|
||||
Addr string `json:"addr"`
|
||||
Port int `json:"port"`
|
||||
Cmd string `json:"cmd"`
|
||||
Addr string `json:"addr"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// Config holds runtime configuration for [Run].
|
||||
@@ -65,7 +68,6 @@ type Config struct {
|
||||
Carrier string
|
||||
RoomURL string
|
||||
KeyHex string
|
||||
ClientID string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
@@ -88,6 +90,10 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
AuthHook handshake.AuthFunc
|
||||
}
|
||||
|
||||
// Run starts the server with the given configuration.
|
||||
@@ -100,9 +106,14 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
hook := cfg.AuthHook
|
||||
if hook == nil {
|
||||
hook = defaultAuthHook
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
clientID: cfg.ClientID,
|
||||
authHook: hook,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
@@ -182,7 +193,7 @@ func (s *Server) bringUpLink(
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
ClientID: s.clientID,
|
||||
DeviceID: "",
|
||||
Name: names.Generate(),
|
||||
OnData: s.onData,
|
||||
DNSServer: s.dnsServer,
|
||||
@@ -270,6 +281,8 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
s.installSession()
|
||||
}
|
||||
@@ -284,6 +297,8 @@ func (s *Server) closeSession() {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -296,9 +311,9 @@ func (s *Server) onData(data []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
|
||||
// The loop tolerates session bounces (reconnects) by waiting until a fresh
|
||||
// session is installed instead of terminating the server.
|
||||
// serve drives the smux Accept loop. The first accepted stream on a given
|
||||
// smux session is the control stream — the handshake runs there. Subsequent
|
||||
// streams are tunnel streams and proxy traffic.
|
||||
func (s *Server) serve(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
@@ -319,6 +334,12 @@ func (s *Server) serve(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if !s.handshakeReady() {
|
||||
if !s.acceptHandshake(ctx, sess) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
select {
|
||||
@@ -339,6 +360,62 @@ func (s *Server) serve(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handshakeReady reports whether the current session has completed its
|
||||
// handshake. The session is reset on reconnect, so this is recomputed.
|
||||
func (s *Server) handshakeReady() bool {
|
||||
s.sessMu.RLock()
|
||||
defer s.sessMu.RUnlock()
|
||||
return s.sessionID != ""
|
||||
}
|
||||
|
||||
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
|
||||
s.reinstallSession(sess)
|
||||
return false
|
||||
}
|
||||
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
|
||||
hello, sid, err := handshake.Server(stream, s.authHook)
|
||||
_ = stream.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logger.Warnf("handshake failed: %v", err)
|
||||
_ = stream.Close()
|
||||
s.reinstallSession(sess)
|
||||
return false
|
||||
}
|
||||
s.sessMu.Lock()
|
||||
s.deviceID = hello.DeviceID
|
||||
s.sessionID = sid
|
||||
s.sessMu.Unlock()
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.parkControlStream(stream)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// parkControlStream blocks reading from the control stream until it closes.
|
||||
// Future control messages (kick, rate updates, etc.) would be dispatched here.
|
||||
func (s *Server) parkControlStream(stream *smux.Stream) {
|
||||
defer func() { _ = stream.Close() }()
|
||||
buf := make([]byte, 64)
|
||||
for {
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
s.closeSession()
|
||||
if s.ln != nil {
|
||||
@@ -362,10 +439,6 @@ func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
|
||||
header = append(header, tmp[:n]...)
|
||||
if req, ok := parseConnectRequest(header); ok {
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
if !s.authorizeRequest(req) {
|
||||
logger.Warnf("sid=%d rejected: client_id mismatch", stream.ID())
|
||||
return
|
||||
}
|
||||
s.dispatch(stream, req)
|
||||
return
|
||||
}
|
||||
@@ -390,8 +463,10 @@ func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
|
||||
return req, true
|
||||
}
|
||||
|
||||
func (s *Server) authorizeRequest(req ConnectRequest) bool {
|
||||
return req.ClientID == s.clientID
|
||||
// defaultAuthHook admits every client and assigns a random session ID.
|
||||
// Replace it via [Config.AuthHook] to plug in real authorization.
|
||||
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
|
||||
return uuid.NewString(), nil
|
||||
}
|
||||
|
||||
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
|
||||
@@ -47,10 +47,9 @@ func TestSmuxConfig(t *testing.T) {
|
||||
|
||||
func TestParseConnectRequest(t *testing.T) {
|
||||
buf, err := json.Marshal(ConnectRequest{
|
||||
Cmd: "connect",
|
||||
ClientID: "client-1", //nolint:goconst // test literal, repetition is intentional
|
||||
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
|
||||
Port: 443,
|
||||
Cmd: "connect",
|
||||
Addr: "example.com", //nolint:goconst // test literal, repetition is intentional
|
||||
Port: 443,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
@@ -60,7 +59,7 @@ func TestParseConnectRequest(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("parseConnectRequest() returned ok=false")
|
||||
}
|
||||
if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 {
|
||||
if req.Addr != "example.com" || req.Port != 443 {
|
||||
t.Fatalf("parseConnectRequest() = %+v", req)
|
||||
}
|
||||
|
||||
@@ -72,13 +71,13 @@ func TestParseConnectRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeRequest(t *testing.T) {
|
||||
s := &Server{clientID: "client-1"}
|
||||
if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) {
|
||||
t.Fatal("authorizeRequest() rejected valid client")
|
||||
func TestDefaultAuthHook(t *testing.T) {
|
||||
sid, err := defaultAuthHook("dev", map[string]any{"x": 1})
|
||||
if err != nil {
|
||||
t.Fatalf("defaultAuthHook() err = %v", err)
|
||||
}
|
||||
if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) {
|
||||
t.Fatal("authorizeRequest() accepted wrong client")
|
||||
if sid == "" {
|
||||
t.Fatal("defaultAuthHook() returned empty session id")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +300,7 @@ func TestSocks5ConnectTruncatesLongDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
func TestHandleStreamDispatchAfterConnect(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
@@ -323,7 +322,7 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
(&Server{clientID: "expected"}).handleStream(context.Background(), stream)
|
||||
(&Server{}).handleStream(context.Background(), stream)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
@@ -333,10 +332,9 @@ func TestHandleStreamRejectsWrongClientID(t *testing.T) {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
req, err := json.Marshal(ConnectRequest{
|
||||
Cmd: "connect",
|
||||
ClientID: "wrong",
|
||||
Addr: "example.com",
|
||||
Port: 443,
|
||||
Cmd: "connect",
|
||||
Addr: "127.0.0.1",
|
||||
Port: 1, // unreachable port — dispatch will fail dial and exit
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
|
||||
Reference in New Issue
Block a user