feat: wire WatchConnection into Dial — Read unblocks on session end

Dial now sets SetEndedCallback to close the pipe with ErrSessionEnded
and starts WatchConnection in a goroutine. Consumers (e.g. sing-box)
get a concrete error from Read when the session dies permanently.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
zarazaex69
2026-05-11 14:02:28 +03:00
parent f287dc117a
commit f4ab63b5fa
2 changed files with 72 additions and 6 deletions

View File

@@ -45,6 +45,8 @@ var (
ErrTokenRequired = errors.New("olcrtc: Token required when using direct engine mode")
// ErrRoomCreationUnsupported is returned when the auth provider cannot create rooms.
ErrRoomCreationUnsupported = errors.New("olcrtc: auth provider does not support room creation")
// ErrSessionEnded is returned from Read/Write when the session has ended permanently.
ErrSessionEnded = errors.New("olcrtc: session ended")
)
// Config is the input to [New].
@@ -177,10 +179,16 @@ func newDirect(ctx context.Context, cfg Config) (*Session, error) {
// Dial connects and returns a [net.Conn] backed by the WebRTC data channel.
// It combines [Session.Connect] + wrapping in a single call.
// The connection watcher runs in the background for the lifetime of ctx;
// when the session ends permanently, Read will return an error.
func (s *Session) Dial(ctx context.Context) (net.Conn, error) {
s.inner.SetEndedCallback(func(_ string) {
_ = s.pw.CloseWithError(ErrSessionEnded)
})
if err := s.Connect(ctx); err != nil {
return nil, err
}
go s.inner.WatchConnection(ctx)
return &conn{s: s}, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/auth"
"github.com/openlibrecommunity/olcrtc/internal/engine"
@@ -18,15 +19,21 @@ const (
// --- stub engine ---
type stubSession struct{ connected bool }
type stubSession struct {
connected bool
onEnded func(string)
watchBlock chan struct{} // closed to unblock WatchConnection
}
func newStubSession() *stubSession { return &stubSession{watchBlock: make(chan struct{})} }
func (s *stubSession) Connect(_ context.Context) error { s.connected = true; return nil }
func (s *stubSession) Send(_ []byte) error { return nil }
func (s *stubSession) Close() error { return nil }
func (s *stubSession) SetReconnectCallback(_ func(*webrtc.DataChannel)) {}
func (s *stubSession) SetShouldReconnect(_ func() bool) {}
func (s *stubSession) SetEndedCallback(_ func(string)) {}
func (s *stubSession) WatchConnection(_ context.Context) {}
func (s *stubSession) SetEndedCallback(cb func(string)) { s.onEnded = cb }
func (s *stubSession) WatchConnection(_ context.Context) { <-s.watchBlock }
func (s *stubSession) CanSend() bool { return s.connected }
func (s *stubSession) GetSendQueue() chan []byte { return nil }
func (s *stubSession) GetBufferedAmount() uint64 { return 0 }
@@ -38,12 +45,24 @@ var _ engine.Session = (*stubSession)(nil)
func registerStubEngine(t *testing.T, name string) {
t.Helper()
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
return &stubSession{}, nil
return newStubSession(), nil
})
t.Cleanup(func() {
// Re-register a no-op so subsequent tests don't break.
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
return &stubSession{}, nil
return newStubSession(), nil
})
})
}
// registerStubEngineControlled registers an engine that returns a pre-built stub the test controls.
func registerStubEngineControlled(t *testing.T, name string, stub *stubSession) {
t.Helper()
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
return stub, nil
})
t.Cleanup(func() {
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
return newStubSession(), nil
})
})
}
@@ -192,6 +211,45 @@ func TestCreateRoom_OK(t *testing.T) {
}
}
func TestDial_ReadUnblocksOnSessionEnd(t *testing.T) {
stub := newStubSession()
registerStubEngineControlled(t, "stub-ended", stub)
sess, err := olcrtc.New(context.Background(), olcrtc.Config{
Engine: "stub-ended",
URL: stubURL,
Token: stubToken,
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
c, err := sess.Dial(context.Background())
if err != nil {
t.Fatalf("Dial() error = %v", err)
}
readErr := make(chan error, 1)
go func() {
buf := make([]byte, 4)
_, err := c.Read(buf)
readErr <- err
}()
// Simulate session ending permanently.
stub.onEnded("test reason")
close(stub.watchBlock)
select {
case err := <-readErr:
if err == nil {
t.Fatal("Read() should return error after session ended")
}
case <-time.After(time.Second):
t.Fatal("Read() did not unblock after session ended")
}
}
func TestDial_RoundTrip(t *testing.T) {
registerStubEngine(t, "stub-dial")