From f4ab63b5fa0af72936264d96cade367cdc7884c1 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Mon, 11 May 2026 14:02:28 +0300 Subject: [PATCH] =?UTF-8?q?feat:=20wire=20WatchConnection=20into=20Dial=20?= =?UTF-8?q?=E2=80=94=20Read=20unblocks=20on=20session=20end?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dial now sets SetEndedCallback to close the pipe with ErrSessionEnded and starts WatchConnection in a goroutine. Consumers (e.g. sing-box) get a concrete error from Read when the session dies permanently. Co-Authored-By: Claude Sonnet 4.6 --- pkg/olcrtc/olcrtc.go | 8 +++++ pkg/olcrtc/olcrtc_test.go | 70 +++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/pkg/olcrtc/olcrtc.go b/pkg/olcrtc/olcrtc.go index 6b2c9d9..7999d63 100644 --- a/pkg/olcrtc/olcrtc.go +++ b/pkg/olcrtc/olcrtc.go @@ -45,6 +45,8 @@ var ( ErrTokenRequired = errors.New("olcrtc: Token required when using direct engine mode") // ErrRoomCreationUnsupported is returned when the auth provider cannot create rooms. ErrRoomCreationUnsupported = errors.New("olcrtc: auth provider does not support room creation") + // ErrSessionEnded is returned from Read/Write when the session has ended permanently. + ErrSessionEnded = errors.New("olcrtc: session ended") ) // Config is the input to [New]. @@ -177,10 +179,16 @@ func newDirect(ctx context.Context, cfg Config) (*Session, error) { // Dial connects and returns a [net.Conn] backed by the WebRTC data channel. // It combines [Session.Connect] + wrapping in a single call. +// The connection watcher runs in the background for the lifetime of ctx; +// when the session ends permanently, Read will return an error. func (s *Session) Dial(ctx context.Context) (net.Conn, error) { + s.inner.SetEndedCallback(func(_ string) { + _ = s.pw.CloseWithError(ErrSessionEnded) + }) if err := s.Connect(ctx); err != nil { return nil, err } + go s.inner.WatchConnection(ctx) return &conn{s: s}, nil } diff --git a/pkg/olcrtc/olcrtc_test.go b/pkg/olcrtc/olcrtc_test.go index 47af01e..27ca8a4 100644 --- a/pkg/olcrtc/olcrtc_test.go +++ b/pkg/olcrtc/olcrtc_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "github.com/openlibrecommunity/olcrtc/internal/auth" "github.com/openlibrecommunity/olcrtc/internal/engine" @@ -18,15 +19,21 @@ const ( // --- stub engine --- -type stubSession struct{ connected bool } +type stubSession struct { + connected bool + onEnded func(string) + watchBlock chan struct{} // closed to unblock WatchConnection +} + +func newStubSession() *stubSession { return &stubSession{watchBlock: make(chan struct{})} } func (s *stubSession) Connect(_ context.Context) error { s.connected = true; return nil } func (s *stubSession) Send(_ []byte) error { return nil } func (s *stubSession) Close() error { return nil } func (s *stubSession) SetReconnectCallback(_ func(*webrtc.DataChannel)) {} func (s *stubSession) SetShouldReconnect(_ func() bool) {} -func (s *stubSession) SetEndedCallback(_ func(string)) {} -func (s *stubSession) WatchConnection(_ context.Context) {} +func (s *stubSession) SetEndedCallback(cb func(string)) { s.onEnded = cb } +func (s *stubSession) WatchConnection(_ context.Context) { <-s.watchBlock } func (s *stubSession) CanSend() bool { return s.connected } func (s *stubSession) GetSendQueue() chan []byte { return nil } func (s *stubSession) GetBufferedAmount() uint64 { return 0 } @@ -38,12 +45,24 @@ var _ engine.Session = (*stubSession)(nil) func registerStubEngine(t *testing.T, name string) { t.Helper() engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) { - return &stubSession{}, nil + return newStubSession(), nil }) t.Cleanup(func() { - // Re-register a no-op so subsequent tests don't break. engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) { - return &stubSession{}, nil + return newStubSession(), nil + }) + }) +} + +// registerStubEngineControlled registers an engine that returns a pre-built stub the test controls. +func registerStubEngineControlled(t *testing.T, name string, stub *stubSession) { + t.Helper() + engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) { + return stub, nil + }) + t.Cleanup(func() { + engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) { + return newStubSession(), nil }) }) } @@ -192,6 +211,45 @@ func TestCreateRoom_OK(t *testing.T) { } } +func TestDial_ReadUnblocksOnSessionEnd(t *testing.T) { + stub := newStubSession() + registerStubEngineControlled(t, "stub-ended", stub) + + sess, err := olcrtc.New(context.Background(), olcrtc.Config{ + Engine: "stub-ended", + URL: stubURL, + Token: stubToken, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + c, err := sess.Dial(context.Background()) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + + readErr := make(chan error, 1) + go func() { + buf := make([]byte, 4) + _, err := c.Read(buf) + readErr <- err + }() + + // Simulate session ending permanently. + stub.onEnded("test reason") + close(stub.watchBlock) + + select { + case err := <-readErr: + if err == nil { + t.Fatal("Read() should return error after session ended") + } + case <-time.After(time.Second): + t.Fatal("Read() did not unblock after session ended") + } +} + func TestDial_RoundTrip(t *testing.T) { registerStubEngine(t, "stub-dial")