Files
olcrtc/internal/engine/goolom/lifecycle.go
2026-05-16 02:10:34 +03:00

426 lines
9.8 KiB
Go

package goolom
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/openlibrecommunity/olcrtc/internal/engine"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/protect"
"github.com/pion/webrtc/v4"
)
// Connect starts the WebRTC connection process.
func (s *Session) Connect(ctx context.Context) error {
s.closed.Store(false)
s.resetMediaState()
config := webrtc.Configuration{
ICEServers: []webrtc.ICEServer{{URLs: []string{"stun:stun.rtc.yandex.net:3478"}}},
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
}
if err := s.setupPeerConnections(config); err != nil {
return err
}
keepAliveCh, sessionCloseCh := s.resetSession()
var dcReady chan struct{}
if s.onData != nil {
var err error
s.dc, err = s.pcPub.CreateDataChannel("olcrtc", nil)
if err != nil {
return fmt.Errorf("create dc: %w", err)
}
dcReady = make(chan struct{})
s.setupDataChannelHandlers(dcReady, sessionCloseCh)
}
if err := s.dialWebSocket(); err != nil {
return err
}
s.setupICEHandlers()
s.startBackgroundGoroutines(ctx, keepAliveCh)
if s.onData != nil {
select {
case <-dcReady:
return nil
case <-time.After(15 * time.Second):
return ErrDataChannelTimeout
case <-ctx.Done():
return fmt.Errorf("connect context cancelled: %w", ctx.Err())
}
}
return s.waitForMediaReady(ctx, 20*time.Second)
}
func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration) error {
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case <-s.subscriberConn:
case <-timer.C:
return ErrSubscriberMediaTimeout
case <-ctx.Done():
return fmt.Errorf("connect context cancelled: %w", ctx.Err())
}
return nil
}
func (s *Session) setupPeerConnections(config webrtc.Configuration) error {
settingEngine := webrtc.SettingEngine{}
if protect.Protector != nil {
settingEngine.SetICEProxyDialer(protect.NewProxyDialer())
}
settingEngine.LoggerFactory = logger.NewPionLoggerFactory()
api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
var err error
s.pcSub, err = api.NewPeerConnection(config)
if err != nil {
return fmt.Errorf("new sub pc: %w", err)
}
s.pcSub.OnConnectionStateChange(s.onSubscriberConnectionStateChange)
s.pcSub.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
if track.Kind() != webrtc.RTPCodecTypeVideo {
return
}
logger.Infof("goolom remote video track: codec=%s stream=%s track=%s",
track.Codec().MimeType, track.StreamID(), track.ID())
if cb := s.videoTrackHandler(); cb != nil {
cb(track, receiver)
}
})
s.pcPub, err = api.NewPeerConnection(config)
if err != nil {
return fmt.Errorf("new pub pc: %w", err)
}
s.pcPub.OnConnectionStateChange(s.onPublisherConnectionStateChange)
if err := s.attachPendingVideoTracks(); err != nil {
return err
}
return nil
}
func (s *Session) dialWebSocket() error {
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil)
if err != nil {
return fmt.Errorf("dial ws: %w", err)
}
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
s.ws = ws
ws.SetPongHandler(func(string) error {
_ = ws.SetReadDeadline(time.Now().Add(wsReadTimeout))
return nil
})
_ = ws.SetReadDeadline(time.Now().Add(wsReadTimeout))
return nil
}
func (s *Session) startBackgroundGoroutines(ctx context.Context, keepAliveCh chan struct{}) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.keepAlive(keepAliveCh)
}()
_ = s.sendHello()
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleSignaling(ctx)
}()
}
func (s *Session) onConnectionStateChange(state webrtc.PeerConnectionState) {
if !s.closed.Load() && state == webrtc.PeerConnectionStateFailed {
s.queueReconnect()
}
}
func (s *Session) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) {
logger.Debugf("goolom subscriber state: %s", state.String())
switch state {
case webrtc.PeerConnectionStateConnected:
s.subscriberReady.Store(true)
closeSignal(s.subscriberConn)
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed,
webrtc.PeerConnectionStateClosed:
s.subscriberReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
s.onConnectionStateChange(state)
}
func (s *Session) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) {
logger.Debugf("goolom publisher state: %s", state.String())
switch state {
case webrtc.PeerConnectionStateConnected:
s.publisherReady.Store(true)
closeSignal(s.publisherConn)
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed,
webrtc.PeerConnectionStateClosed:
s.publisherReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
s.onConnectionStateChange(state)
}
// Close terminates the session and releases resources.
func (s *Session) Close() error {
alreadyClosing := s.closed.Swap(true)
s.sendQueueClosed.Store(true)
if !alreadyClosing {
leaveUID := uuid.New().String()
leaveAck := s.registerAckWaiter(leaveUID)
// 2s matches our jitsi tear-down budget. The reason is the same:
// without giving the server time to register the leave, a
// back-to-back reconnection from the same client collides with a
// still-alive ghost participant on the SFU side and inherits
// stale media-flow state.
if s.sendLeave(leaveUID) {
_ = s.waitForAck(leaveUID, leaveAck, 2*time.Second)
} else {
s.removeAckWaiter(leaveUID)
}
}
closeSignal(s.closeCh)
s.stopSession()
if s.dc != nil {
_ = s.dc.Close()
}
if s.pcPub != nil {
_ = s.pcPub.Close()
}
if s.pcSub != nil {
_ = s.pcSub.Close()
}
if s.ws != nil {
s.wsMu.Lock()
_ = s.ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
time.Now().Add(time.Second))
_ = s.ws.Close()
s.wsMu.Unlock()
}
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
}
return nil
}
// WatchConnection monitors the connection lifecycle and reconnects as needed.
func (s *Session) WatchConnection(ctx context.Context) {
const maxReconnects = 10
const reconnectWindow = 5 * time.Minute
for {
select {
case <-ctx.Done():
return
case <-s.closeCh:
return
case <-s.reconnectCh:
if s.handleReconnectAttempt(ctx, maxReconnects, reconnectWindow) {
return
}
}
}
}
func (s *Session) handleReconnectAttempt(ctx context.Context, maxReconnects int, reconnectWindow time.Duration) bool {
if time.Since(s.lastReconnect) > reconnectWindow {
s.reconnectCount = 0
}
s.reconnectCount++
s.lastReconnect = time.Now()
if s.reconnectCount > maxReconnects {
s.signalEnded("reconnect limit reached")
return true
}
backoff := time.Duration(s.reconnectCount) * 2 * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
return s.retryReconnect(ctx, backoff)
}
func (s *Session) retryReconnect(ctx context.Context, backoff time.Duration) bool {
for {
if err := s.reconnect(ctx); err != nil {
logger.Debugf("reconnect failed: %v", err)
select {
case <-ctx.Done():
return true
case <-s.closeCh:
return true
case <-time.After(backoff):
continue
}
}
break
}
return false
}
func (s *Session) reconnect(ctx context.Context) error {
s.reconnecting.Store(true)
defer s.reconnecting.Store(false)
s.sendLeave(uuid.New().String())
time.Sleep(500 * time.Millisecond)
s.stopSession()
if s.dc != nil {
_ = s.dc.Close()
}
if s.pcPub != nil {
_ = s.pcPub.Close()
}
if s.pcSub != nil {
_ = s.pcSub.Close()
}
if s.ws != nil {
s.wsMu.Lock()
_ = s.ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
time.Now().Add(time.Second))
_ = s.ws.Close()
s.wsMu.Unlock()
}
if s.onReconnect != nil {
s.onReconnect(nil)
}
time.Sleep(3 * time.Second)
if s.refresh == nil {
return ErrNoRefresh
}
creds, err := s.refresh(ctx)
if err != nil {
return fmt.Errorf("reconnect refresh: %w", err)
}
s.applyRefreshedCredentials(creds)
if err := s.Connect(ctx); err != nil {
return err
}
if s.onReconnect != nil {
s.onReconnect(s.dc)
}
s.drainReconnectQueue()
return nil
}
func (s *Session) applyRefreshedCredentials(creds engine.Credentials) {
if creds.URL != "" {
s.mediaServerURL = creds.URL
}
if creds.Token != "" {
s.peerID = creds.Token
}
if creds.Extra == nil {
return
}
if v := creds.Extra[credentialKeyRoomID]; v != "" {
s.roomID = v
}
if v := creds.Extra[credentialKeyCredentials]; v != "" {
s.credentials = v
}
if v := creds.Extra[credentialKeyRoomURL]; v != "" {
s.roomURL = v
}
if v := creds.Extra[credentialKeyTelemetryReferer]; v != "" {
s.telemetryReferer = v
}
}
func (s *Session) drainReconnectQueue() {
for {
select {
case <-s.reconnectCh:
default:
return
}
}
}
func (s *Session) queueReconnect() {
if s.closed.Load() || s.reconnecting.Load() {
return
}
if s.shouldReconnect != nil && !s.shouldReconnect() {
return
}
select {
case s.reconnectCh <- struct{}{}:
default:
}
}
func (s *Session) stopSession() {
s.stopTelemetry()
s.sessionMu.Lock()
closeSignal(s.keepAliveCh)
closeSignal(s.sessionCloseCh)
s.sessionMu.Unlock()
}
func (s *Session) resetSession() (chan struct{}, chan struct{}) {
s.sessionMu.Lock()
defer s.sessionMu.Unlock()
s.keepAliveCh = make(chan struct{})
s.sessionCloseCh = make(chan struct{})
return s.keepAliveCh, s.sessionCloseCh
}
func (s *Session) resetMediaState() {
s.subscriberReady.Store(false)
s.publisherReady.Store(false)
s.subscriberConn = make(chan struct{})
s.publisherConn = make(chan struct{})
}
func (s *Session) signalEnded(reason string) {
s.closed.Store(true)
s.stopTelemetry()
if s.onEnded != nil {
s.onEnded(reason)
}
}