mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-06-06 20:39:47 +00:00
feat: wire WatchConnection into Dial — Read unblocks on session end
Dial now sets SetEndedCallback to close the pipe with ErrSessionEnded and starts WatchConnection in a goroutine. Consumers (e.g. sing-box) get a concrete error from Read when the session dies permanently. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -45,6 +45,8 @@ var (
|
||||
ErrTokenRequired = errors.New("olcrtc: Token required when using direct engine mode")
|
||||
// ErrRoomCreationUnsupported is returned when the auth provider cannot create rooms.
|
||||
ErrRoomCreationUnsupported = errors.New("olcrtc: auth provider does not support room creation")
|
||||
// ErrSessionEnded is returned from Read/Write when the session has ended permanently.
|
||||
ErrSessionEnded = errors.New("olcrtc: session ended")
|
||||
)
|
||||
|
||||
// Config is the input to [New].
|
||||
@@ -177,10 +179,16 @@ func newDirect(ctx context.Context, cfg Config) (*Session, error) {
|
||||
|
||||
// Dial connects and returns a [net.Conn] backed by the WebRTC data channel.
|
||||
// It combines [Session.Connect] + wrapping in a single call.
|
||||
// The connection watcher runs in the background for the lifetime of ctx;
|
||||
// when the session ends permanently, Read will return an error.
|
||||
func (s *Session) Dial(ctx context.Context) (net.Conn, error) {
|
||||
s.inner.SetEndedCallback(func(_ string) {
|
||||
_ = s.pw.CloseWithError(ErrSessionEnded)
|
||||
})
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.inner.WatchConnection(ctx)
|
||||
return &conn{s: s}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/auth"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/engine"
|
||||
@@ -18,15 +19,21 @@ const (
|
||||
|
||||
// --- stub engine ---
|
||||
|
||||
type stubSession struct{ connected bool }
|
||||
type stubSession struct {
|
||||
connected bool
|
||||
onEnded func(string)
|
||||
watchBlock chan struct{} // closed to unblock WatchConnection
|
||||
}
|
||||
|
||||
func newStubSession() *stubSession { return &stubSession{watchBlock: make(chan struct{})} }
|
||||
|
||||
func (s *stubSession) Connect(_ context.Context) error { s.connected = true; return nil }
|
||||
func (s *stubSession) Send(_ []byte) error { return nil }
|
||||
func (s *stubSession) Close() error { return nil }
|
||||
func (s *stubSession) SetReconnectCallback(_ func(*webrtc.DataChannel)) {}
|
||||
func (s *stubSession) SetShouldReconnect(_ func() bool) {}
|
||||
func (s *stubSession) SetEndedCallback(_ func(string)) {}
|
||||
func (s *stubSession) WatchConnection(_ context.Context) {}
|
||||
func (s *stubSession) SetEndedCallback(cb func(string)) { s.onEnded = cb }
|
||||
func (s *stubSession) WatchConnection(_ context.Context) { <-s.watchBlock }
|
||||
func (s *stubSession) CanSend() bool { return s.connected }
|
||||
func (s *stubSession) GetSendQueue() chan []byte { return nil }
|
||||
func (s *stubSession) GetBufferedAmount() uint64 { return 0 }
|
||||
@@ -38,12 +45,24 @@ var _ engine.Session = (*stubSession)(nil)
|
||||
func registerStubEngine(t *testing.T, name string) {
|
||||
t.Helper()
|
||||
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
|
||||
return &stubSession{}, nil
|
||||
return newStubSession(), nil
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
// Re-register a no-op so subsequent tests don't break.
|
||||
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
|
||||
return &stubSession{}, nil
|
||||
return newStubSession(), nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// registerStubEngineControlled registers an engine that returns a pre-built stub the test controls.
|
||||
func registerStubEngineControlled(t *testing.T, name string, stub *stubSession) {
|
||||
t.Helper()
|
||||
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
|
||||
return stub, nil
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
engine.Register(name, func(_ context.Context, _ engine.Config) (engine.Session, error) {
|
||||
return newStubSession(), nil
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -192,6 +211,45 @@ func TestCreateRoom_OK(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDial_ReadUnblocksOnSessionEnd(t *testing.T) {
|
||||
stub := newStubSession()
|
||||
registerStubEngineControlled(t, "stub-ended", stub)
|
||||
|
||||
sess, err := olcrtc.New(context.Background(), olcrtc.Config{
|
||||
Engine: "stub-ended",
|
||||
URL: stubURL,
|
||||
Token: stubToken,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
c, err := sess.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial() error = %v", err)
|
||||
}
|
||||
|
||||
readErr := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 4)
|
||||
_, err := c.Read(buf)
|
||||
readErr <- err
|
||||
}()
|
||||
|
||||
// Simulate session ending permanently.
|
||||
stub.onEnded("test reason")
|
||||
close(stub.watchBlock)
|
||||
|
||||
select {
|
||||
case err := <-readErr:
|
||||
if err == nil {
|
||||
t.Fatal("Read() should return error after session ended")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Read() did not unblock after session ended")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDial_RoundTrip(t *testing.T) {
|
||||
registerStubEngine(t, "stub-dial")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user