diff --git a/internal/client/client.go b/internal/client/client.go index dca6c48..cfd489e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -504,7 +504,11 @@ func (c *Client) shutdown() { } func setupCipher(keyHex string) (*crypto.Cipher, error) { - return runtime.SetupCipher(keyHex) + cipher, err := runtime.SetupCipher(keyHex) + if err != nil { + return nil, fmt.Errorf("client: %w", err) + } + return cipher, nil } func (c *Client) onData(data []byte) { diff --git a/internal/control/control.go b/internal/control/control.go index 24b2974..d208afb 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -309,9 +309,16 @@ func parseMessage(raw []byte) (Message, error) { } func writeFrame(w io.Writer, msg Message) error { - return framing.WriteJSON(w, msg, MaxMessageSize) + if err := framing.WriteJSON(w, msg, MaxMessageSize); err != nil { + return fmt.Errorf("control: %w", err) + } + return nil } func readFrame(r io.Reader) ([]byte, error) { - return framing.ReadBytes(r, MaxMessageSize) + body, err := framing.ReadBytes(r, MaxMessageSize) + if err != nil { + return nil, fmt.Errorf("control: %w", err) + } + return body, nil } diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go index 2399c76..3d11422 100644 --- a/internal/handshake/handshake.go +++ b/internal/handshake/handshake.go @@ -192,9 +192,16 @@ func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) { } func writeFrame(w io.Writer, msg any) error { - return framing.WriteJSON(w, msg, MaxMessageSize) + if err := framing.WriteJSON(w, msg, MaxMessageSize); err != nil { + return fmt.Errorf("handshake: %w", err) + } + return nil } func readFrame(r io.Reader) ([]byte, error) { - return framing.ReadBytes(r, MaxMessageSize) + body, err := framing.ReadBytes(r, MaxMessageSize) + if err != nil { + return nil, fmt.Errorf("handshake: %w", err) + } + return body, nil } diff --git a/internal/server/server.go b/internal/server/server.go index df746c3..338d8fd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -183,7 +183,11 @@ func Run(ctx context.Context, cfg Config) error { } func setupCipher(keyHex string) (*crypto.Cipher, error) { - return runtime.SetupCipher(keyHex) + cipher, err := runtime.SetupCipher(keyHex) + if err != nil { + return nil, fmt.Errorf("server: %w", err) + } + return cipher, nil } func (s *Server) setupResolver() { diff --git a/internal/transport/common/common.go b/internal/transport/common/common.go index 757da4a..5c98fb9 100644 --- a/internal/transport/common/common.go +++ b/internal/transport/common/common.go @@ -114,31 +114,47 @@ func (r *Reassembler) Push(fragment Fragment) (Result, []byte) { return ResultDuplicate, nil } - msg, ok := r.inbound[fragment.Seq] - if !ok || msg.CRC != fragment.CRC || msg.TotalLen != fragment.TotalLen || - len(msg.frags) != int(fragment.FragTotal) { - msg = &InboundMessage{ - TotalLen: fragment.TotalLen, - CRC: fragment.CRC, - frags: make([][]byte, fragment.FragTotal), - remain: int(fragment.FragTotal), - } - r.inbound[fragment.Seq] = msg - } + msg := r.upsert(fragment) if int(fragment.FragIdx) >= len(msg.frags) { return ResultIgnore, nil } - if msg.frags[fragment.FragIdx] == nil { - chunk := make([]byte, len(fragment.Payload)) - copy(chunk, fragment.Payload) - msg.frags[fragment.FragIdx] = chunk - msg.remain-- - } + r.storeChunk(msg, fragment) if msg.remain > 0 { return ResultPartial, nil } + return r.deliver(fragment.Seq, msg) +} - delete(r.inbound, fragment.Seq) +// upsert returns the inbound message tracking entry for fragment.Seq, +// creating a fresh entry if no compatible one is present. +func (r *Reassembler) upsert(fragment Fragment) *InboundMessage { + msg, ok := r.inbound[fragment.Seq] + if ok && msg.CRC == fragment.CRC && msg.TotalLen == fragment.TotalLen && + len(msg.frags) == int(fragment.FragTotal) { + return msg + } + msg = &InboundMessage{ + TotalLen: fragment.TotalLen, + CRC: fragment.CRC, + frags: make([][]byte, fragment.FragTotal), + remain: int(fragment.FragTotal), + } + r.inbound[fragment.Seq] = msg + return msg +} + +func (r *Reassembler) storeChunk(msg *InboundMessage, fragment Fragment) { + if msg.frags[fragment.FragIdx] != nil { + return + } + chunk := make([]byte, len(fragment.Payload)) + copy(chunk, fragment.Payload) + msg.frags[fragment.FragIdx] = chunk + msg.remain-- +} + +func (r *Reassembler) deliver(seq uint32, msg *InboundMessage) (Result, []byte) { + delete(r.inbound, seq) data := assemble(msg) if crc32.ChecksumIEEE(data) != msg.CRC { return ResultIgnore, nil @@ -146,7 +162,7 @@ func (r *Reassembler) Push(fragment Fragment) (Result, []byte) { if len(r.delivered) > r.maxRecent { r.delivered = make(map[uint32]uint32) } - r.delivered[fragment.Seq] = msg.CRC + r.delivered[seq] = msg.CRC return ResultDelivered, data } diff --git a/internal/transport/common/common_test.go b/internal/transport/common/common_test.go index 1080be4..5b89e3d 100644 --- a/internal/transport/common/common_test.go +++ b/internal/transport/common/common_test.go @@ -43,9 +43,9 @@ func TestReassemblerDeliveredAndDuplicate(t *testing.T) { result, data := r.Push(common.Fragment{ Seq: 1, CRC: crc, - TotalLen: uint32(len(payload)), + TotalLen: uint32(len(payload)), //nolint:gosec // bounded test fixture FragIdx: uint16(i), - FragTotal: uint16(len(frags)), + FragTotal: uint16(len(frags)), //nolint:gosec // bounded test fixture Payload: frag, }) if i < len(frags)-1 { @@ -63,9 +63,9 @@ func TestReassemblerDeliveredAndDuplicate(t *testing.T) { result, _ := r.Push(common.Fragment{ Seq: 1, CRC: crc, - TotalLen: uint32(len(payload)), - FragIdx: uint16(len(frags) - 1), - FragTotal: uint16(len(frags)), + TotalLen: uint32(len(payload)), //nolint:gosec // bounded test fixture + FragIdx: uint16(len(frags) - 1), //nolint:gosec // bounded test fixture + FragTotal: uint16(len(frags)), //nolint:gosec // bounded test fixture Payload: frags[len(frags)-1], }) if result != common.ResultDuplicate { @@ -80,9 +80,9 @@ func TestReassemblerIgnoresCRCMismatch(t *testing.T) { result, _ := r.Push(common.Fragment{ Seq: 1, CRC: 0xdeadbeef, // wrong - TotalLen: uint32(len(payload)), + TotalLen: uint32(len(payload)), //nolint:gosec // bounded test fixture FragIdx: 0, - FragTotal: uint16(len(frags)), + FragTotal: uint16(len(frags)), //nolint:gosec // bounded test fixture Payload: frags[0], }) if result != common.ResultDelivered { diff --git a/internal/transport/datachannel/transport_test.go b/internal/transport/datachannel/transport_test.go index 3113f4b..6deba5c 100644 --- a/internal/transport/datachannel/transport_test.go +++ b/internal/transport/datachannel/transport_test.go @@ -100,13 +100,15 @@ func TestNewAndFeatures(t *testing.T) { func TestNewErrorPaths(t *testing.T) { registerCarrier("datachannel-fail-create", nil, errDCBoom) - if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}); err == nil || err.Error() != "open engine session: boom" { + _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}) + if err == nil || err.Error() != "open engine session: boom" { t.Fatalf("New() error = %v", err) } nonByteStream := &stubSession{caps: engine.Capabilities{}} registerCarrier("datachannel-no-stream", nonByteStream, nil) - if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}); !errors.Is(err, ErrByteStreamUnsupported) { + _, err = New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}) + if !errors.Is(err, ErrByteStreamUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrByteStreamUnsupported) } } diff --git a/internal/transport/seichannel/options.go b/internal/transport/seichannel/options.go index 43f3eba..528640c 100644 --- a/internal/transport/seichannel/options.go +++ b/internal/transport/seichannel/options.go @@ -2,6 +2,7 @@ package seichannel import ( "fmt" + "time" "github.com/openlibrecommunity/olcrtc/internal/transport" ) @@ -17,6 +18,23 @@ type Options struct { // TransportOptions marks Options as belonging to the transport options family. func (Options) TransportOptions() {} +// withDefaults fills unset Options fields with the package defaults. +func (o Options) withDefaults() Options { + if o.FPS <= 0 { + o.FPS = defaultFPS + } + if o.BatchSize <= 0 { + o.BatchSize = defaultBatchSize + } + if o.FragmentSize <= 0 { + o.FragmentSize = defaultFragmentSize + } + if o.AckTimeoutMS <= 0 { + o.AckTimeoutMS = int(defaultAckTimeout / time.Millisecond) + } + return o +} + func optionsFrom(cfg transport.Config) (Options, error) { if cfg.Options == nil { return Options{}, nil diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index 4f49c97..eea5259 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -150,23 +150,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) return nil, fmt.Errorf("create local video track: %w", err) } - fps := opts.FPS - if fps <= 0 { - fps = defaultFPS - } - batchSize := opts.BatchSize - if batchSize <= 0 { - batchSize = defaultBatchSize - } - fragmentSize := opts.FragmentSize - if fragmentSize <= 0 { - fragmentSize = defaultFragmentSize - } - ackTimeout := defaultAckTimeout - if opts.AckTimeoutMS > 0 { - ackTimeout = time.Duration(opts.AckTimeoutMS) * time.Millisecond - } - + opts = opts.withDefaults() tr := &streamTransport{ stream: stream, track: track, @@ -177,10 +161,10 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) writerDone: make(chan struct{}), acks: common.NewAckRegistry(), reassembler: common.NewReassembler(256), - fragmentSize: fragmentSize, - ackTimeout: ackTimeout, - frameInterval: time.Second / time.Duration(fps), - batchSize: batchSize, + fragmentSize: opts.FragmentSize, + ackTimeout: time.Duration(opts.AckTimeoutMS) * time.Millisecond, + frameInterval: time.Second / time.Duration(opts.FPS), + batchSize: opts.BatchSize, } if err := stream.AddTrack(track); err != nil { @@ -470,8 +454,8 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) { p.onData(data) } p.sendAck(frame.seq, frame.crc) - default: - // Partial or Ignore: do nothing. + case common.ResultPartial, common.ResultIgnore: + // fragment stored or discarded; no peer response needed yet. } } diff --git a/internal/transport/seichannel/transport_unit_test.go b/internal/transport/seichannel/transport_unit_test.go index ed8b53a..f9d90ba 100644 --- a/internal/transport/seichannel/transport_unit_test.go +++ b/internal/transport/seichannel/transport_unit_test.go @@ -148,14 +148,16 @@ func TestNewErrorPaths(t *testing.T) { enginebuiltin.Register("seichannel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errBoom }) - if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description + _, err := New(context.Background(), transport.Config{Carrier: "seichannel-create-fails"}) + if err == nil || err.Error() != "open engine session: boom" { t.Fatalf("New() error = %v", err) } enginebuiltin.Register("seichannel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { + _, err = New(context.Background(), transport.Config{Carrier: "seichannel-no-video"}) + if !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } } diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go index 8974e47..44fbb60 100644 --- a/internal/transport/videochannel/transport.go +++ b/internal/transport/videochannel/transport.go @@ -125,7 +125,9 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) // Stream/track IDs must be unique per peer: Jitsi/Jicofo keys participant // sources by msid (stream-id+track-id) and rejects a session-accept whose // msid collides with one already in the conference. - track, err := webrtc.NewTrackLocalStaticSample(codec.capability, "videochannel-"+common.RandomID(), "olcrtc-"+common.RandomID()) + streamID := "videochannel-" + common.RandomID() + trackID := "olcrtc-" + common.RandomID() + track, err := webrtc.NewTrackLocalStaticSample(codec.capability, streamID, trackID) if err != nil { return nil, fmt.Errorf("create local video track: %w", err) } @@ -580,8 +582,8 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) { p.onData(data) } p.sendAck(frame.seq, frame.crc) - default: - // Partial or Ignore: do nothing. + case common.ResultPartial, common.ResultIgnore: + // fragment stored or discarded; no peer response needed yet. } } diff --git a/internal/transport/videochannel/transport_unit_test.go b/internal/transport/videochannel/transport_unit_test.go index 35a60f8..623f9f9 100644 --- a/internal/transport/videochannel/transport_unit_test.go +++ b/internal/transport/videochannel/transport_unit_test.go @@ -129,17 +129,22 @@ func TestNewCallbacksFeaturesAndClose(t *testing.T) { } func TestNewErrorPaths(t *testing.T) { - enginebuiltin.Register("videochannel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { - return nil, errVideoUnitBoom - }) - if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description + enginebuiltin.Register( + "videochannel-create-fails", + func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return nil, errVideoUnitBoom + }, + ) + _, err := New(context.Background(), transport.Config{Carrier: "videochannel-create-fails"}) + if err == nil || err.Error() != "open engine session: boom" { t.Fatalf("New() error = %v", err) } enginebuiltin.Register("videochannel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { + _, err = New(context.Background(), transport.Config{Carrier: "videochannel-no-video"}) + if !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } } diff --git a/internal/transport/vp8channel/transport_unit_test.go b/internal/transport/vp8channel/transport_unit_test.go index 7821232..6cd97a5 100644 --- a/internal/transport/vp8channel/transport_unit_test.go +++ b/internal/transport/vp8channel/transport_unit_test.go @@ -169,14 +169,16 @@ func TestNewErrorPaths(t *testing.T) { enginebuiltin.Register("vp8channel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errVP8UnitBoom }) - if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description + _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-create-fails"}) + if err == nil || err.Error() != "open engine session: boom" { t.Fatalf("New() error = %v", err) } enginebuiltin.Register("vp8channel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { + _, err = New(context.Background(), transport.Config{Carrier: "vp8channel-no-video"}) + if !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } } diff --git a/mobile/mobile_test.go b/mobile/mobile_test.go index 0c81b84..75c4810 100644 --- a/mobile/mobile_test.go +++ b/mobile/mobile_test.go @@ -78,7 +78,6 @@ func TestProtectorAndLogging(t *testing.T) { } } -//nolint:cyclop // compact setter smoke test verifies several related defaults together func TestDefaultsAndSetters(t *testing.T) { resetMobileGlobals(t)