diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index b5cf0dd..deb9f44 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -1021,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { if err := ln.Connect(context.Background()); err != nil { t.Fatalf("Connect() error = %v", err) } - if !ln.CanSend() { - t.Fatal("CanSend() = false, want true") - } + assertLinkCanSendAfterConnect(t, ln, transportName) if err := ln.Close(); err != nil { t.Fatalf("Close() error = %v", err) } @@ -1033,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { } } +func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) { + t.Helper() + + if transportName == transportSEI { + if ln.CanSend() { + t.Fatal("CanSend() = true before peer seichannel frame") + } + return + } + if !ln.CanSend() { + t.Fatal("CanSend() = false, want true") + } +} + //nolint:cyclop // table-driven test naturally has many branches func TestRealProviderTransportMatrix(t *testing.T) { if !*realE2E { diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index 6cb7f9b..73b54f9 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -35,6 +35,7 @@ const ( protocolVersion byte = 1 frameTypeData byte = 1 frameTypeAck byte = 2 + frameTypeHello byte = 3 ) var ( @@ -86,6 +87,7 @@ type streamTransport struct { nextSeq atomic.Uint32 closed atomic.Bool writerUp atomic.Bool + peerReady atomic.Bool sendMu sync.Mutex startWriter sync.Once ackMu sync.Mutex @@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) { // CanSend reports whether transport is ready for sending. func (p *streamTransport) CanSend() bool { - return !p.closed.Load() && p.stream.CanSend() + return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend() } // Features describes the current seichannel transport semantics. @@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() { ticker := time.NewTicker(p.effectiveFrameInterval()) defer ticker.Stop() - idle := buildVideoAccessUnit(nil) + idle := buildVideoAccessUnit(encodeHelloFrame()) for { select { @@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) { } switch frame.typ { + case frameTypeHello: + p.peerReady.Store(true) case frameTypeAck: + p.peerReady.Store(true) p.resolveAck(frame.seq, frame.crc) case frameTypeData: + p.peerReady.Store(true) p.handleInboundFrame(frame) } } @@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload out[5] = frameTypeData binary.BigEndian.PutUint32(out[6:10], seq) binary.BigEndian.PutUint32(out[10:14], crc) - binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic - binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic copy(out[22:], payload) return out @@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte { return out } +func encodeHelloFrame() []byte { + out := make([]byte, 6) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeHello + return out +} + func decodeTransportFrame(data []byte) (transportFrame, error) { if len(data) < 6 { return transportFrame{}, ErrFrameTooShort @@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) { frame := transportFrame{typ: data[5]} switch frame.typ { + case frameTypeHello: + return frame, nil case frameTypeAck: if len(data) < 14 { return transportFrame{}, ErrAckTooShort diff --git a/internal/transport/seichannel/transport_test.go b/internal/transport/seichannel/transport_test.go index 8f11c6f..51c8272 100644 --- a/internal/transport/seichannel/transport_test.go +++ b/internal/transport/seichannel/transport_test.go @@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) { t.Fatalf("payload mismatch: got=%q", decoded.payload) } } + +func TestHelloFrameRoundTrip(t *testing.T) { + hello, err := decodeTransportFrame(encodeHelloFrame()) + if err != nil { + t.Fatalf("decodeTransportFrame(hello) failed: %v", err) + } + if hello.typ != frameTypeHello { + t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello) + } +} diff --git a/internal/transport/seichannel/transport_unit_test.go b/internal/transport/seichannel/transport_unit_test.go index 00abf58..716b970 100644 --- a/internal/transport/seichannel/transport_unit_test.go +++ b/internal/transport/seichannel/transport_unit_test.go @@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) { if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched { t.Fatal("callbacks/watch were not forwarded") } + if tr.CanSend() { + t.Fatal("CanSend() = true before peer hello") + } + tr.handleSample(buildVideoAccessUnit(encodeHelloFrame())) if !tr.CanSend() { - t.Fatal("CanSend() = false, want true") + t.Fatal("CanSend() = false after peer hello") } if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description t.Fatalf("Features() = %+v", features)