Files
olcrtc/internal/e2e/tunnel_test.go
2026-05-07 21:52:14 +03:00

1201 lines
28 KiB
Go

//nolint:all // Test file keeps scenario setup inline.
package e2e
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"errors"
"flag"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"github.com/openlibrecommunity/olcrtc/internal/carrier"
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/provider/jazz"
"github.com/openlibrecommunity/olcrtc/internal/provider/wbstream"
"github.com/openlibrecommunity/olcrtc/internal/server"
"github.com/pion/webrtc/v4"
)
// testKeyHex is the fixed hex-encoded key handed to both the server and the
// client in every tunnel started by this file, so the two sides always agree.
const testKeyHex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"

// Flags for the opt-in real-provider matrix (TestRealProviderTransportMatrix).
// That test skips itself unless -olcrtc.real-e2e is passed.
var (
	// realE2E enables the real-provider matrix.
	realE2E = flag.Bool("olcrtc.real-e2e", false, "run real provider e2e matrix against external WebRTC services")
	// realE2ECarriers selects which carriers the matrix iterates over.
	realE2ECarriers = flag.String("olcrtc.real-carriers", "telemost,wbstream", "comma-separated carriers for real e2e")
	// realE2ETransports selects which transports are tried per carrier.
	realE2ETransports = flag.String("olcrtc.real-transports", "datachannel,videochannel,seichannel,vp8channel", "comma-separated transports for real e2e")
	// realE2EJazzRoom pins a SaluteJazz room ("roomID:password"); empty means create one.
	realE2EJazzRoom = flag.String("olcrtc.real-jazz-room", "", "SaluteJazz room for real e2e, format roomID:password; autogenerated when empty")
	// realE2ETelemostRoom is a Telemost join URL or bare room id.
	realE2ETelemostRoom = flag.String("olcrtc.real-telemost-room", "41514917109506", "Telemost room URL or id for real e2e")
	// realE2EWBStreamRoom pins a WB Stream room id; empty means create one.
	realE2EWBStreamRoom = flag.String("olcrtc.real-wbstream-room", "", "WB Stream room id for real e2e; autogenerated when empty")
	// realE2ETimeout bounds the individual waits of each real carrier/transport case.
	realE2ETimeout = flag.Duration("olcrtc.real-timeout", 90*time.Second, "timeout per real e2e provider/transport case")
)
// memorySession is an in-process carrier.Session. Its byte stream and its
// video track are the very same memoryStream.
type memorySession struct {
	stream *memoryStream
}

// Capabilities reports that this session offers both a byte stream and a
// video track.
func (ms *memorySession) Capabilities() carrier.Capabilities {
	var caps carrier.Capabilities
	caps.ByteStream = true
	caps.VideoTrack = true
	return caps
}

// OpenByteStream hands out the shared memoryStream; it never fails.
func (ms *memorySession) OpenByteStream() (carrier.ByteStream, error) {
	var bs carrier.ByteStream = ms.stream
	return bs, nil
}

// OpenVideoTrack hands out the shared memoryStream; it never fails.
func (ms *memorySession) OpenVideoTrack() (carrier.VideoTrack, error) {
	var vt carrier.VideoTrack = ms.stream
	return vt, nil
}
// memoryRoom is the shared in-process "conference" joined by every
// memoryStream that one registered memory carrier creates.
type memoryRoom struct {
	mu      sync.Mutex
	streams map[*memoryStream]struct{}
}

// snapshot copies the current members under r.mu so callers can invoke
// member callbacks without holding the room lock.
func (r *memoryRoom) snapshot() []*memoryStream {
	r.mu.Lock()
	defer r.mu.Unlock()
	members := make([]*memoryStream, 0, len(r.streams))
	for member := range r.streams {
		members = append(members, member)
	}
	return members
}

// connectedCount returns how many members are connected and not closed.
func (r *memoryRoom) connectedCount() int {
	r.mu.Lock()
	defer r.mu.Unlock()
	var n int
	for member := range r.streams {
		if !member.isConnected() {
			continue
		}
		n++
	}
	return n
}

// waitConnected polls every 10ms for up to 3s until at least want members are
// connected, and fails the test otherwise.
func (r *memoryRoom) waitConnected(t *testing.T, want int) {
	t.Helper()
	for deadline := time.Now().Add(3 * time.Second); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		if r.connectedCount() >= want {
			return
		}
	}
	t.Fatalf("memory room connected streams = %d, want at least %d", r.connectedCount(), want)
}

// triggerReconnect fires the reconnect callback on every member.
func (r *memoryRoom) triggerReconnect() {
	for _, member := range r.snapshot() {
		member.triggerReconnect()
	}
}

// triggerEnded fires the ended callback, with reason, on every member.
func (r *memoryRoom) triggerEnded(reason string) {
	for _, member := range r.snapshot() {
		member.triggerEnded(reason)
	}
}
// memoryStream is the in-process stand-in for a carrier connection. It is
// returned by memorySession as both the byte stream and the video track, and
// moves payloads between members of the same memoryRoom without any network.
type memoryStream struct {
	// room is the memoryRoom whose other members receive this stream's Send payloads.
	room *memoryRoom
	// onData receives inbound payloads; set once from carrier.Config.OnData at construction.
	onData func([]byte)
	// mu guards the fields below (and is taken when reading onData).
	mu sync.Mutex
	// connected is set by Connect and cleared by Close.
	connected bool
	// closed is set by Close; afterwards Send fails and deliver drops payloads.
	closed bool
	// reconnect is the callback invoked by triggerReconnect.
	reconnect func()
	// ended is the callback invoked by triggerEnded.
	ended func(string)
	// track is stored by AddTrack; nothing else in this file reads it.
	track webrtc.TrackLocal
	// trackCB is stored by SetTrackHandler; nothing else in this file reads it.
	trackCB func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
	// pending buffers payloads delivered before Connect; Connect replays them.
	pending [][]byte
}

// Connect marks the stream connected and then replays, in arrival order, any
// payloads that were delivered while it was still disconnected. It never fails.
//
// NOTE(review): connected is set before the backlog is replayed outside the
// lock, so a concurrent deliver can hand a newer payload to onData ahead of
// the queued older ones. Tighten this if ordering-sensitive transports flake.
func (s *memoryStream) Connect(context.Context) error {
	s.mu.Lock()
	s.connected = true
	pending := s.pending
	s.pending = nil
	onData := s.onData
	s.mu.Unlock()
	for _, payload := range pending {
		if onData != nil {
			onData(payload)
		}
	}
	return nil
}

// Send copies data once and offers it to every other member of the room. It
// returns io.ErrClosedPipe after Close; otherwise it always succeeds, and each
// peer decides on its own whether to queue, deliver, or drop the payload.
func (s *memoryStream) Send(data []byte) error {
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return io.ErrClosedPipe
	}
	s.mu.Unlock()
	payload := append([]byte(nil), data...)
	// Snapshot peers under the room lock, then deliver without holding it.
	s.room.mu.Lock()
	peers := make([]*memoryStream, 0, len(s.room.streams))
	for peer := range s.room.streams {
		if peer != s {
			peers = append(peers, peer)
		}
	}
	s.room.mu.Unlock()
	for _, peer := range peers {
		peer.deliver(payload)
	}
	return nil
}

// deliver hands one inbound payload to this stream. A copy is queued while the
// stream has not connected yet. The payload is dropped after Close or when no
// onData callback is set. Otherwise onData gets a private copy, called
// outside the lock.
func (s *memoryStream) deliver(data []byte) {
	s.mu.Lock()
	if !s.connected && !s.closed {
		s.pending = append(s.pending, append([]byte(nil), data...))
		s.mu.Unlock()
		return
	}
	ready := !s.closed && s.onData != nil
	onData := s.onData
	s.mu.Unlock()
	if ready {
		onData(append([]byte(nil), data...))
	}
}

// Close marks the stream closed and disconnected and never fails. The stream
// is not removed from room.streams.
func (s *memoryStream) Close() error {
	s.mu.Lock()
	s.closed = true
	s.connected = false
	s.mu.Unlock()
	return nil
}

// SetReconnectCallback stores the callback that triggerReconnect fires.
func (s *memoryStream) SetReconnectCallback(cb func()) {
	s.mu.Lock()
	s.reconnect = cb
	s.mu.Unlock()
}

// SetShouldReconnect is accepted for interface compatibility and ignored.
func (s *memoryStream) SetShouldReconnect(func() bool) {}

// SetEndedCallback stores the callback that triggerEnded fires.
func (s *memoryStream) SetEndedCallback(cb func(string)) {
	s.mu.Lock()
	s.ended = cb
	s.mu.Unlock()
}

// WatchConnection blocks until ctx is done; the memory stream never drops on
// its own.
func (s *memoryStream) WatchConnection(ctx context.Context) {
	<-ctx.Done()
}

// CanSend reports whether the stream is connected and not closed.
func (s *memoryStream) CanSend() bool {
	return s.isConnected()
}

// AddTrack records track and never fails.
func (s *memoryStream) AddTrack(track webrtc.TrackLocal) error {
	s.mu.Lock()
	s.track = track
	s.mu.Unlock()
	return nil
}

// SetTrackHandler records cb; this fake never produces remote tracks.
func (s *memoryStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) {
	s.mu.Lock()
	s.trackCB = cb
	s.mu.Unlock()
}

// isConnected reports connected && !closed under the stream lock.
func (s *memoryStream) isConnected() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.connected && !s.closed
}

// triggerReconnect invokes the reconnect callback (outside the lock), but
// only while the stream is connected and open and a callback is set.
func (s *memoryStream) triggerReconnect() {
	s.mu.Lock()
	reconnect := s.reconnect
	ready := s.connected && !s.closed && reconnect != nil
	s.mu.Unlock()
	if ready {
		reconnect()
	}
}

// triggerEnded invokes the ended callback with reason (outside the lock), but
// only while the stream is connected and open and a callback is set.
func (s *memoryStream) triggerEnded(reason string) {
	s.mu.Lock()
	ended := s.ended
	ready := s.connected && !s.closed && ended != nil
	s.mu.Unlock()
	if ready {
		ended(reason)
	}
}
// registerMemoryCarrier makes sure the built-in session defaults are
// registered. It then registers a fresh in-memory carrier named
// "e2e-memory-" + t.Name() and returns that name together with the room its
// sessions join.
//
// The actual registration is delegated to registerMemoryCarrierAs so the two
// helpers cannot drift apart.
func registerMemoryCarrier(t *testing.T) (string, *memoryRoom) {
	t.Helper()
	session.RegisterDefaults()
	name := "e2e-memory-" + t.Name()
	return name, registerMemoryCarrierAs(t, name)
}
// registerMemoryCarrierAs registers an in-memory carrier factory under name
// and returns the room it feeds. Every session the factory creates gets its
// own memoryStream, which joins the returned room and forwards inbound
// payloads to cfg.OnData.
func registerMemoryCarrierAs(t *testing.T, name string) *memoryRoom {
	t.Helper()
	shared := &memoryRoom{streams: map[*memoryStream]struct{}{}}
	factory := func(_ context.Context, cfg carrier.Config) (carrier.Session, error) {
		member := &memoryStream{room: shared, onData: cfg.OnData}
		shared.mu.Lock()
		shared.streams[member] = struct{}{}
		shared.mu.Unlock()
		return &memorySession{stream: member}, nil
	}
	carrier.Register(name, factory)
	return shared
}
// builtInCarrierNames lists the built-in carriers the matrix tests iterate
// over. Each call returns a new slice, so callers may modify it.
func builtInCarrierNames() []string {
	names := []string{
		"jazz",
		"telemost",
		"wbstream",
	}
	return names
}

// builtInTransportNames lists the built-in transports the matrix tests
// iterate over. Each call returns a new slice.
func builtInTransportNames() []string {
	names := []string{
		"datachannel",
		"videochannel",
		"seichannel",
		"vp8channel",
	}
	return names
}
// realE2EExpectedToPass reports whether a real-provider case is expected to
// succeed. Telemost is only expected to work over the video-based transports
// (videochannel, vp8channel); every other carrier is expected to pass on every
// transport.
func realE2EExpectedToPass(carrierName, transportName string) bool {
	if carrierName != "telemost" {
		return true
	}
	switch transportName {
	case "videochannel", "vp8channel":
		return true
	}
	return false
}

// realE2EExpectation returns the log label for a case: "SUCCESS" when it is
// expected to pass, "EXPECTED FAIL" otherwise.
func realE2EExpectation(carrierName, transportName string) string {
	label := "EXPECTED FAIL"
	if realE2EExpectedToPass(carrierName, transportName) {
		label = "SUCCESS"
	}
	return label
}
// splitTestList turns a comma-separated flag value into its non-empty,
// whitespace-trimmed items, in order. The result is never nil.
func splitTestList(value string) []string {
	fields := strings.Split(value, ",")
	out := make([]string, 0, len(fields))
	for i := range fields {
		if item := strings.TrimSpace(fields[i]); item != "" {
			out = append(out, item)
		}
	}
	return out
}
// realRoomURL resolves the room to use for a real-provider run of carrierName.
// Rooms given through flags take precedence. Otherwise jazz and wbstream rooms
// are created on the fly via their provider packages, and creation errors
// fail the test. Bare Telemost ids are expanded to a telemost.yandex.ru join
// URL. Unknown carriers yield "".
func realRoomURL(ctx context.Context, t *testing.T, carrierName string) string {
	t.Helper()
	switch carrierName {
	case "jazz":
		if fixed := *realE2EJazzRoom; fixed != "" {
			return fixed
		}
		created, err := jazz.CreateRoom(ctx)
		if err != nil {
			t.Fatalf("create real jazz room: %v", err)
		}
		return created.RoomID + ":" + created.Password
	case "telemost":
		id := *realE2ETelemostRoom
		if id == "" || strings.HasPrefix(id, "http://") || strings.HasPrefix(id, "https://") {
			return id
		}
		return "https://telemost.yandex.ru/j/" + id
	case "wbstream":
		if fixed := *realE2EWBStreamRoom; fixed != "" {
			return fixed
		}
		created, err := wbstream.CreateRoom(ctx, "olcrtc-e2e-room")
		if err != nil {
			t.Fatalf("create real wbstream room: %v", err)
		}
		return created
	}
	return ""
}
// requireRealRoom resolves the room like realRoomURL does, but fails the test
// when no room can be determined for carrierName.
func requireRealRoom(ctx context.Context, t *testing.T, carrierName string) string {
	t.Helper()
	if roomURL := realRoomURL(ctx, t, carrierName); roomURL != "" {
		return roomURL
	}
	t.Fatalf("missing room for %s", carrierName)
	return ""
}
// validSessionConfig builds a session.Config for mode/carrier/transport. Every
// other field gets a fixed value, and TestBuiltInProviderTransportMatrixValidates
// expects the result to pass session.Validate for all built-in combinations.
func validSessionConfig(mode, carrierName, transportName string) session.Config {
	var cfg session.Config

	// Routing and identity.
	cfg.Mode = mode
	cfg.Link = "direct"
	cfg.Carrier = carrierName
	cfg.Transport = transportName
	cfg.RoomID = "room"
	cfg.ClientID = "client-1"
	cfg.KeyHex = testKeyHex

	// Local endpoints.
	cfg.SOCKSHost = "127.0.0.1"
	cfg.SOCKSPort = 1080
	cfg.DNSServer = "127.0.0.1:53"

	// Video transport tuning.
	cfg.VideoWidth = 1080
	cfg.VideoHeight = 1080
	cfg.VideoFPS = 30
	cfg.VideoBitrate = "1M"
	cfg.VideoHW = "none"
	cfg.VideoCodec = "tile"
	cfg.VideoTileModule = 4
	cfg.VideoTileRS = 20

	// VP8 transport tuning.
	cfg.VP8FPS = 60
	cfg.VP8BatchSize = 8

	// SEI transport tuning.
	cfg.SEIFPS = 30
	cfg.SEIBatchSize = 4
	cfg.SEIFragmentSize = 512
	cfg.SEIAckTimeoutMS = 1500

	return cfg
}
// validLinkConfig builds a link.Config for carrier/transport by copying the
// matching fields from validSessionConfig in "cnc" mode. The room URL is
// "room" and the name is "e2e-<carrier>-<transport>".
func validLinkConfig(carrierName, transportName string) link.Config {
	base := validSessionConfig("cnc", carrierName, transportName)
	var lc link.Config

	lc.Name = "e2e-" + carrierName + "-" + transportName
	lc.Carrier = base.Carrier
	lc.Transport = base.Transport
	lc.RoomURL = "room"
	lc.ClientID = base.ClientID
	lc.DNSServer = base.DNSServer

	// Media tuning is copied verbatim from the session config.
	lc.VideoWidth, lc.VideoHeight, lc.VideoFPS = base.VideoWidth, base.VideoHeight, base.VideoFPS
	lc.VideoBitrate, lc.VideoHW, lc.VideoCodec = base.VideoBitrate, base.VideoHW, base.VideoCodec
	lc.VideoTileModule, lc.VideoTileRS = base.VideoTileModule, base.VideoTileRS
	lc.VP8FPS, lc.VP8BatchSize = base.VP8FPS, base.VP8BatchSize
	lc.SEIFPS, lc.SEIBatchSize = base.SEIFPS, base.SEIBatchSize
	lc.SEIFragmentSize, lc.SEIAckTimeoutMS = base.SEIFragmentSize, base.SEIAckTimeoutMS

	return lc
}
// startEchoServer listens on an ephemeral IPv4 loopback port and returns its
// address. Every accepted connection gets back each byte it sends. The
// listener is closed via t.Cleanup, which also ends the accept loop.
func startEchoServer(t *testing.T) string {
	t.Helper()
	listener, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen echo: %v", err)
	}
	t.Cleanup(func() { _ = listener.Close() })

	echo := func(c net.Conn) {
		defer func() { _ = c.Close() }()
		_, _ = io.Copy(c, c)
	}
	go func() {
		for {
			c, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			go echo(c)
		}
	}()
	return listener.Addr().String()
}
// freeLocalAddr asks the OS for an unused IPv4 loopback port, releases it
// immediately, and returns the address so a component under test can bind
// it. The port is not held, so another process could in principle grab it
// first.
func freeLocalAddr(t *testing.T) string {
	t.Helper()
	probe, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("reserve local addr: %v", err)
	}
	addr := probe.Addr().String()
	if closeErr := probe.Close(); closeErr != nil {
		t.Fatalf("close reserved addr: %v", closeErr)
	}
	return addr
}
// waitForReady blocks until ready is closed (or receives a value), and fails
// the test if that does not happen within 3 seconds.
func waitForReady(t *testing.T, ready <-chan struct{}) {
	t.Helper()
	timer := time.NewTimer(3 * time.Second)
	defer timer.Stop()
	select {
	case <-ready:
		return
	case <-timer.C:
	}
	t.Fatal("client did not become ready")
}
// tunnelRuntime is the handle for a running client/server tunnel pair.
type tunnelRuntime struct {
	// socksAddr is the address of the client's local SOCKS listener.
	socksAddr string
	// room is the in-memory room backing the tunnel; nil for real-provider tunnels.
	room *memoryRoom
	// cancel cancels the context shared by the client and the server.
	cancel context.CancelFunc
	// clientErr receives the single return value of the client's run call.
	clientErr chan error
	// serverErr receives the single return value of the server's run call.
	serverErr chan error
}
// startTunnel runs server.Run and client.RunWithReady over the "datachannel"
// transport on a freshly registered in-memory carrier. Both use the room
// "room" and testKeyHex. serverClientID and clientClientID are passed to the
// server and the client respectively, so tests can make them disagree.
//
// The server is started first. The client is started only once the server's
// stream has connected to the memory room (waitConnected, up to 3s), and the
// function then waits for the client's ready callback (waitForReady, up to
// 3s). Both run calls share one context that is cancelled on test cleanup.
// Their return values arrive on the runtime's serverErr/clientErr channels.
func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRuntime {
	t.Helper()
	carrierName, room := registerMemoryCarrier(t)
	socksAddr := freeLocalAddr(t)
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	serverErr := make(chan error, 1)
	go func() {
		// Trailing zero/empty positional arguments leave the optional tuning
		// parameters unset; see server.Run for what each position means.
		serverErr <- server.Run(
			ctx,
			"direct",
			"datachannel",
			carrierName,
			"room",
			testKeyHex,
			serverClientID,
			"127.0.0.1:53",
			"",
			0,
			0,
			0,
			0,
			"",
			"",
			0,
			"",
			"",
			0,
			0,
			0,
			0,
			0,
			0,
			0,
			0,
		)
	}()
	// Start the client only after the server side has joined the room.
	room.waitConnected(t, 1)
	ready := make(chan struct{})
	clientErr := make(chan error, 1)
	go func() {
		// NOTE(review): close(ready) panics if RunWithReady ever invokes the
		// ready callback more than once; wrap it in sync.Once if that can happen.
		clientErr <- client.RunWithReady(
			ctx,
			"direct",
			"datachannel",
			carrierName,
			"room",
			testKeyHex,
			clientClientID,
			socksAddr,
			"127.0.0.1:53",
			"",
			"",
			func() { close(ready) },
			0,
			0,
			0,
			"",
			"",
			0,
			"",
			"",
			0,
			0,
			0,
			0,
			0,
			0,
			0,
			0,
		)
	}()
	waitForReady(t, ready)
	return &tunnelRuntime{
		socksAddr: socksAddr,
		room:      room,
		cancel:    cancel,
		serverErr: serverErr,
		clientErr: clientErr,
	}
}
// startRealTunnel starts server.Run and client.RunWithReady against a real
// provider room (roomURL) with the given carrier and transport. It returns
// once the client reports ready.
//
// The server gets a fixed 2-second head start. If the server exits or ctx
// ends during that window, an error is returned. The client then has up to
// -olcrtc.real-timeout to become ready. On every failure path the shared run
// context is cancelled before returning. The returned runtime has no memory
// room (room is nil).
//
// NOTE(review): ctx is the second parameter, not the first as Go convention
// suggests; kept as-is for caller compatibility.
func startRealTunnel(
	t *testing.T,
	ctx context.Context,
	carrierName, transportName, roomURL, serverClientID, clientClientID string,
) (*tunnelRuntime, error) {
	t.Helper()
	session.RegisterDefaults()
	socksAddr := freeLocalAddr(t)
	runCtx, cancel := context.WithCancel(ctx)
	t.Cleanup(cancel)
	serverErr := make(chan error, 1)
	go func() {
		// Positional tuning arguments; see server.Run for the meaning of
		// each position. They mirror the client call below.
		serverErr <- server.Run(
			runCtx,
			"direct",
			transportName,
			carrierName,
			roomURL,
			testKeyHex,
			serverClientID,
			"127.0.0.1:53",
			"",
			0,
			1080,
			1080,
			60,
			"5000k",
			"none",
			512,
			"low",
			"qrcode",
			4,
			20,
			60,
			8,
			30,
			4,
			512,
			1500,
		)
	}()
	// Give the server a head start. Bail out early if it dies or the caller's
	// context ends first.
	select {
	case err := <-serverErr:
		cancel()
		return nil, fmt.Errorf("server exited before client start: %w", err)
	case <-time.After(2 * time.Second):
	case <-runCtx.Done():
		cancel()
		return nil, fmt.Errorf("server context ended before client start: %w", runCtx.Err())
	}
	ready := make(chan struct{})
	clientErr := make(chan error, 1)
	go func() {
		// NOTE(review): close(ready) panics if the ready callback fires more
		// than once; confirm RunWithReady calls it at most once.
		clientErr <- client.RunWithReady(
			runCtx,
			"direct",
			transportName,
			carrierName,
			roomURL,
			testKeyHex,
			clientClientID,
			socksAddr,
			"127.0.0.1:53",
			"",
			"",
			func() { close(ready) },
			1080,
			1080,
			60,
			"5000k",
			"none",
			512,
			"low",
			"qrcode",
			4,
			20,
			60,
			8,
			30,
			4,
			512,
			1500,
		)
	}()
	// Wait for readiness, failing fast if either side exits or time runs out.
	select {
	case <-ready:
	case err := <-clientErr:
		cancel()
		return nil, fmt.Errorf("client exited before ready: %w", err)
	case err := <-serverErr:
		cancel()
		return nil, fmt.Errorf("server exited before client ready: %w", err)
	case <-time.After(*realE2ETimeout):
		cancel()
		return nil, errors.New("real e2e client did not become ready")
	case <-runCtx.Done():
		cancel()
		return nil, fmt.Errorf("real e2e context ended before ready: %w", runCtx.Err())
	}
	return &tunnelRuntime{
		socksAddr: socksAddr,
		cancel:    cancel,
		serverErr: serverErr,
		clientErr: clientErr,
	}, nil
}
// stop cancels the tunnel and fails the test unless both sides shut down
// cleanly.
func (r *tunnelRuntime) stop(t *testing.T) {
	t.Helper()
	err := r.stopErr()
	if err == nil {
		return
	}
	t.Fatal(err)
}

// waitStopped waits, without cancelling anything, for both sides to exit on
// their own, and fails the test on error or timeout.
func (r *tunnelRuntime) waitStopped(t *testing.T) {
	t.Helper()
	err := r.waitStoppedErr()
	if err == nil {
		return
	}
	t.Fatal(err)
}

// stopErr cancels the shared run context, then waits for both sides and
// reports the outcome as an error.
func (r *tunnelRuntime) stopErr() error {
	r.cancel()
	err := r.waitStoppedErr()
	return err
}
// waitStoppedErr waits for the client and then the server to return from
// their run calls, giving each up to 3 seconds.
//
// Both sides are always waited on, in a fixed order (client, then server).
// Previously the order was random because it ranged over a map, and a client
// failure meant the server was never observed. Every problem is reported,
// joined into one error. The result is nil only when both sides returned nil
// in time.
func (r *tunnelRuntime) waitStoppedErr() error {
	sides := []struct {
		name string
		ch   <-chan error
	}{
		{"client", r.clientErr},
		{"server", r.serverErr},
	}
	var errs []error
	for _, side := range sides {
		select {
		case err := <-side.ch:
			if err != nil {
				errs = append(errs, fmt.Errorf("%s returned error: %w", side.name, err))
			}
		case <-time.After(3 * time.Second):
			errs = append(errs, fmt.Errorf("%s did not stop", side.name))
		}
	}
	// errors.Join returns nil for an empty list.
	return errors.Join(errs...)
}
// connectViaSOCKS dials the SOCKS5 proxy at socksAddr (2s dial timeout) and
// performs the no-auth greeting. It then sends a CONNECT to targetAddr, which
// must be an IPv4 "host:port", and returns the established connection.
// Anything other than a fully successful handshake fails the test, after the
// half-open connection has been closed.
func connectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn {
	t.Helper()
	conn, err := net.DialTimeout("tcp4", socksAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial socks: %v", err)
	}
	// fail closes the connection before aborting the test.
	fail := func(format string, args ...any) {
		t.Helper()
		_ = conn.Close()
		t.Fatalf(format, args...)
	}
	if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
		fail("write socks greeting: %v", err)
	}
	greeting := make([]byte, 2)
	if _, err := io.ReadFull(conn, greeting); err != nil {
		fail("read socks greeting: %v", err)
	}
	if !bytes.Equal(greeting, []byte{5, 0}) {
		fail("socks greeting = %v, want [5 0]", greeting)
	}
	host, portText, err := net.SplitHostPort(targetAddr)
	if err != nil {
		fail("split target addr: %v", err)
	}
	// The request uses ATYP=1 (IPv4). Reject other hosts up front instead of
	// silently sending a truncated address (To4 returns nil for them).
	ip4 := net.ParseIP(host).To4()
	if ip4 == nil {
		fail("target host %q is not an IPv4 address", host)
	}
	// Parse with a 16-bit limit so an out-of-range port is reported rather
	// than silently wrapped by the uint16 conversion.
	port, err := strconv.ParseUint(portText, 10, 16)
	if err != nil {
		fail("parse target port: %v", err)
	}
	req := []byte{5, 1, 0, 1}
	req = append(req, ip4...)
	req = binary.BigEndian.AppendUint16(req, uint16(port))
	if _, err := conn.Write(req); err != nil {
		fail("write socks connect: %v", err)
	}
	reply := make([]byte, 10)
	if _, err := io.ReadFull(conn, reply); err != nil {
		fail("read socks connect reply: %v", err)
	}
	if !bytes.Equal(reply, []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) {
		fail("socks reply = %v, want success", reply)
	}
	return conn
}
// connectViaSOCKSExpectFailure performs the same SOCKS5 handshake as
// connectViaSOCKS against socksAddr/targetAddr. Instead of requiring success,
// it returns the raw 10-byte CONNECT reply so callers can assert on the
// failure code. Transport errors, a non-IPv4 target, an out-of-range port, or
// a rejected greeting still fail the test. The greeting is checked so that a
// broken greeting is not mistaken for the CONNECT failure under test. The
// connection is always closed before returning.
func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) []byte {
	t.Helper()
	host, portText, err := net.SplitHostPort(targetAddr)
	if err != nil {
		t.Fatalf("split target addr: %v", err)
	}
	// ATYP=1 (IPv4) is hard-wired below; To4 yields nil for anything else.
	ip4 := net.ParseIP(host).To4()
	if ip4 == nil {
		t.Fatalf("target host %q is not an IPv4 address", host)
	}
	port, err := strconv.ParseUint(portText, 10, 16)
	if err != nil {
		t.Fatalf("parse target port: %v", err)
	}
	conn, err := net.DialTimeout("tcp4", socksAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial socks: %v", err)
	}
	defer func() { _ = conn.Close() }()
	if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
		t.Fatalf("write socks greeting: %v", err)
	}
	greeting := make([]byte, 2)
	if _, err := io.ReadFull(conn, greeting); err != nil {
		t.Fatalf("read socks greeting: %v", err)
	}
	if !bytes.Equal(greeting, []byte{5, 0}) {
		t.Fatalf("socks greeting = %v, want [5 0]", greeting)
	}
	req := []byte{5, 1, 0, 1}
	req = append(req, ip4...)
	req = binary.BigEndian.AppendUint16(req, uint16(port))
	if _, err := conn.Write(req); err != nil {
		t.Fatalf("write socks connect: %v", err)
	}
	reply := make([]byte, 10)
	if _, err := io.ReadFull(conn, reply); err != nil {
		t.Fatalf("read socks failure reply: %v", err)
	}
	return reply
}
// TestBuiltInProviderTransportMatrixValidates checks that session.Validate
// accepts every built-in carrier × transport pair in both "srv" and "cnc"
// modes.
func TestBuiltInProviderTransportMatrixValidates(t *testing.T) {
	session.RegisterDefaults()
	validate := func(t *testing.T, mode, carrierName, transportName string) {
		t.Helper()
		cfg := validSessionConfig(mode, carrierName, transportName)
		if err := session.Validate(cfg); err != nil {
			t.Fatalf("Validate() error = %v", err)
		}
	}
	modes := []string{"srv", "cnc"}
	for _, mode := range modes {
		t.Run(mode, func(t *testing.T) {
			for _, carrierName := range builtInCarrierNames() {
				t.Run(carrierName, func(t *testing.T) {
					for _, transportName := range builtInTransportNames() {
						t.Run(transportName, func(t *testing.T) {
							validate(t, mode, carrierName, transportName)
						})
					}
				})
			}
		})
	}
}
// TestDirectLinkCreatesAllProviderTransportCombinations builds (and
// immediately closes) a "direct" link for every built-in carrier × transport
// pair. The built-in carrier names are backed by in-memory carriers.
//
// NOTE(review): registering memory carriers under the built-in names touches
// the process-wide carrier registry; confirm carrier.Register semantics if a
// later test in the same binary expects the real carriers.
func TestDirectLinkCreatesAllProviderTransportCombinations(t *testing.T) {
	session.RegisterDefaults()
	carriers := builtInCarrierNames()
	for _, name := range carriers {
		registerMemoryCarrierAs(t, name)
	}
	createAndClose := func(t *testing.T, carrierName, transportName string) {
		t.Helper()
		ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName))
		if err != nil {
			t.Fatalf("link.New() error = %v", err)
		}
		if err := ln.Close(); err != nil {
			t.Fatalf("Close() error = %v", err)
		}
	}
	for _, carrierName := range carriers {
		t.Run(carrierName, func(t *testing.T) {
			for _, transportName := range builtInTransportNames() {
				t.Run(transportName, func(t *testing.T) {
					createAndClose(t, carrierName, transportName)
				})
			}
		})
	}
}
// TestDirectLinkConnectsFastProviderTransportMatrix connects a "direct" link
// over the fast transports (datachannel, seichannel) for every built-in
// carrier name, each backed by an in-memory carrier. It checks that the link
// reports CanSend and closes cleanly.
func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
	session.RegisterDefaults()
	for _, carrierName := range builtInCarrierNames() {
		registerMemoryCarrierAs(t, carrierName)
	}
	for _, carrierName := range builtInCarrierNames() {
		t.Run(carrierName, func(t *testing.T) {
			for _, transportName := range []string{"datachannel", "seichannel"} {
				t.Run(transportName, func(t *testing.T) {
					ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName))
					if err != nil {
						t.Fatalf("link.New() error = %v", err)
					}
					// Close exactly once. On the happy path it is closed
					// explicitly so the error can be checked. If Connect or
					// CanSend fails first, cleanup closes it so a failed
					// subtest does not leak the link.
					closeLink := sync.OnceValue(ln.Close)
					t.Cleanup(func() { _ = closeLink() })
					if err := ln.Connect(context.Background()); err != nil {
						t.Fatalf("Connect() error = %v", err)
					}
					if !ln.CanSend() {
						t.Fatal("CanSend() = false, want true")
					}
					if err := closeLink(); err != nil {
						t.Fatalf("Close() error = %v", err)
					}
				})
			}
		})
	}
}
// TestRealProviderTransportMatrix runs the echo round trip against real
// providers for every selected carrier × transport pair. It is opt-in via
// -olcrtc.real-e2e. Each outcome is compared with realE2EExpectedToPass: an
// expected failure is only logged, while an unexpected success or an
// unexpected failure fails the subtest.
func TestRealProviderTransportMatrix(t *testing.T) {
	if !*realE2E {
		t.Skip("real provider e2e disabled; pass -olcrtc.real-e2e with provider room flags")
	}
	carriers, transports := splitTestList(*realE2ECarriers), splitTestList(*realE2ETransports)
	switch {
	case len(carriers) == 0:
		t.Fatal("no real e2e carriers selected")
	case len(transports) == 0:
		t.Fatal("no real e2e transports selected")
	}
	echoAddr := startEchoServer(t)
	for _, carrierName := range carriers {
		t.Run(carrierName, func(t *testing.T) {
			roomCtx, cancelRoom := context.WithTimeout(context.Background(), *realE2ETimeout)
			defer cancelRoom()
			roomURL := requireRealRoom(roomCtx, t, carrierName)
			for _, transportName := range transports {
				t.Run(transportName, func(t *testing.T) {
					expectPass := realE2EExpectedToPass(carrierName, transportName)
					label := realE2EExpectation(carrierName, transportName)
					err := runRealE2ECase(t, carrierName, transportName, roomURL, echoAddr)
					if err == nil {
						if !expectPass {
							t.Fatalf("UNEXPECTED SUCCESS %s/%s", carrierName, transportName)
						}
						t.Logf("%s %s/%s", label, carrierName, transportName)
						return
					}
					if expectPass {
						t.Fatalf("EXPECTED SUCCESS %s/%s failed: %v", carrierName, transportName, err)
					}
					t.Logf("%s %s/%s: %v", label, carrierName, transportName, err)
				})
			}
		})
	}
}
// runRealE2ECase runs one real-provider carrier/transport case end to end. It
// starts a tunnel in roomURL, with server and client both using client ID
// "client-1". It then connects through the client's SOCKS proxy to echoAddr,
// sends one newline-terminated line, and expects the identical line back.
//
// The tunnel context, the SOCKS connect retries, and the echo read deadline
// each get their own -olcrtc.real-timeout budget.
// NOTE(review): these budgets are independent, not one overall deadline.
//
// The tunnel is always stopped before returning. A stop failure is reported
// through the named err result only when the case itself succeeded.
func runRealE2ECase(t *testing.T, carrierName, transportName, roomURL, echoAddr string) (err error) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), *realE2ETimeout)
	defer cancel()
	rt, err := startRealTunnel(t, ctx, carrierName, transportName, roomURL, "client-1", "client-1")
	if err != nil {
		return err
	}
	// Registered before the conn close below, so it runs after it (defers are
	// LIFO): the SOCKS connection closes first, then the tunnel is stopped.
	defer func() {
		if stopErr := rt.stopErr(); err == nil && stopErr != nil {
			err = stopErr
		}
	}()
	conn, err := connectViaSOCKSWithin(rt.socksAddr, echoAddr, *realE2ETimeout)
	if err != nil {
		return err
	}
	defer func() { _ = conn.Close() }()
	payload := []byte("olcrtc-real-e2e-" + carrierName + "-" + transportName + "\n")
	if _, err := conn.Write(payload); err != nil {
		return fmt.Errorf("write real e2e payload: %w", err)
	}
	if err := conn.SetReadDeadline(time.Now().Add(*realE2ETimeout)); err != nil {
		return fmt.Errorf("set real e2e read deadline: %w", err)
	}
	line, err := bufio.NewReader(conn).ReadBytes('\n')
	if err != nil {
		return fmt.Errorf("read real e2e echo: %w", err)
	}
	if !bytes.Equal(line, payload) {
		return fmt.Errorf("real e2e echo = %q, want %q", line, payload)
	}
	return nil
}
// TestClientServerSOCKSTunnelOverMemoryDatachannel sends one line through the
// client's SOCKS proxy over an in-memory datachannel tunnel and expects the
// echo server's copy of it back unchanged.
func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) {
	echoAddr := startEchoServer(t)
	rt := startTunnel(t, "client-1", "client-1")
	defer rt.stop(t)

	conn := connectViaSOCKS(t, rt.socksAddr, echoAddr)
	defer func() { _ = conn.Close() }()

	want := []byte("olcrtc-e2e-payload\n")
	if _, err := conn.Write(want); err != nil {
		t.Fatalf("write tunneled payload: %v", err)
	}
	deadline := time.Now().Add(3 * time.Second)
	if err := conn.SetReadDeadline(deadline); err != nil {
		t.Fatalf("set read deadline: %v", err)
	}
	reader := bufio.NewReader(conn)
	got, err := reader.ReadBytes('\n')
	switch {
	case err != nil:
		t.Fatalf("read tunneled echo: %v", err)
	case !bytes.Equal(got, want):
		t.Fatalf("echo = %q, want %q", got, want)
	}
}
// TestWrongClientIDIsRejected starts the server and the client with different
// client IDs. It expects the SOCKS CONNECT to be refused with reply code 4
// (host unreachable).
func TestWrongClientIDIsRejected(t *testing.T) {
	echoAddr := startEchoServer(t)
	rt := startTunnel(t, "server-client", "wrong-client")
	defer rt.stop(t)

	hostUnreachable := []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
	if got := connectViaSOCKSExpectFailure(t, rt.socksAddr, echoAddr); !bytes.Equal(got, hostUnreachable) {
		t.Fatalf("wrong client-id reply = %v, want host unreachable", got)
	}
}
// TestFrequentReconnectsStillAllowNewSOCKSConnections forces a reconnect of
// every room member five times in a row. After each reconnect it proves that
// a brand-new SOCKS connection can still carry an echoed line.
func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
	echoAddr := startEchoServer(t)
	rt := startTunnel(t, "client-1", "client-1")
	defer rt.stop(t)

	// roundTrip performs one reconnect plus echo check; its deferred close
	// runs on success and on t.Fatalf alike.
	roundTrip := func(i int) {
		rt.room.triggerReconnect()
		conn := eventuallyConnectViaSOCKS(t, rt.socksAddr, echoAddr)
		defer func() { _ = conn.Close() }()

		want := fmt.Appendf(nil, "after-reconnect-%d\n", i)
		if _, err := conn.Write(want); err != nil {
			t.Fatalf("write after reconnect %d: %v", i, err)
		}
		if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
			t.Fatalf("set deadline after reconnect %d: %v", i, err)
		}
		got, err := bufio.NewReader(conn).ReadBytes('\n')
		if err != nil {
			t.Fatalf("read after reconnect %d: %v", i, err)
		}
		if !bytes.Equal(got, want) {
			t.Fatalf("echo after reconnect %d = %q, want %q", i, got, want)
		}
	}
	for i := 0; i < 5; i++ {
		roundTrip(i)
	}
}
// TestEndedCallbackStopsClientAndServer fires the "ended" callback on every
// room member. It expects both tunnel sides to return cleanly on their own,
// without the run context being cancelled first.
func TestEndedCallbackStopsClientAndServer(t *testing.T) {
	tunnel := startTunnel(t, "client-1", "client-1")
	room := tunnel.room
	room.triggerEnded("conference ended")
	tunnel.waitStopped(t)
}
// eventuallyConnectViaSOCKS retries the SOCKS handshake for up to 3 seconds
// and fails the test if it never succeeds.
func eventuallyConnectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn {
	t.Helper()
	const retryWindow = 3 * time.Second
	return eventuallyConnectViaSOCKSWithin(t, socksAddr, targetAddr, retryWindow)
}
// eventuallyConnectViaSOCKSWithin is the test-fatal wrapper around
// connectViaSOCKSWithin. It returns the connection or fails the test with the
// last handshake error.
func eventuallyConnectViaSOCKSWithin(t *testing.T, socksAddr, targetAddr string, timeout time.Duration) net.Conn {
	t.Helper()
	conn, err := connectViaSOCKSWithin(socksAddr, targetAddr, timeout)
	if err == nil {
		return conn
	}
	t.Fatal(err)
	return nil
}
// connectViaSOCKSWithin repeatedly tries tryConnectViaSOCKS until one attempt
// succeeds or timeout elapses. Between attempts it sleeps 250ms for the first
// three failures and 1s after that.
//
// At least one attempt is always made, even when timeout <= 0, so the
// returned error always wraps a real failure. Previously a non-positive
// timeout produced "%!w(<nil>)". Sleeps are clamped to the time left, so a
// final attempt happens at the deadline instead of the loop oversleeping
// past it.
func connectViaSOCKSWithin(socksAddr, targetAddr string, timeout time.Duration) (net.Conn, error) {
	deadline := time.Now().Add(timeout)
	var lastErr error
	for attempt := 1; ; attempt++ {
		conn, err := tryConnectViaSOCKS(socksAddr, targetAddr)
		if err == nil {
			return conn, nil
		}
		lastErr = err
		remaining := time.Until(deadline)
		if remaining <= 0 {
			break
		}
		sleep := 250 * time.Millisecond
		if attempt > 3 {
			sleep = time.Second
		}
		time.Sleep(min(sleep, remaining))
	}
	return nil, fmt.Errorf("connect via SOCKS failed after %s: %w", timeout, lastErr)
}
// tryConnectViaSOCKS makes one non-fatal attempt to open a SOCKS5 tunnel. It
// validates targetAddr (an IPv4 "host:port" with a 16-bit port), dials
// socksAddr with a 500ms timeout, sends a no-auth greeting and issues a
// CONNECT. On success the connection is returned positioned just after the
// 10-byte reply. On any failure the connection is closed and an error naming
// the failed step is returned.
//
// NOTE(review): only the dial is time-bounded. The greeting and reply reads
// have no deadline, so a proxy that accepts but never answers blocks this
// call (and any retry loop around it) indefinitely.
func tryConnectViaSOCKS(socksAddr, targetAddr string) (net.Conn, error) {
	// Validate the target before touching the network. The request below is
	// hard-wired to ATYP=1 (IPv4), and To4 would yield nil (a truncated
	// request) for anything else. ParseUint's 16-bit limit stops an
	// out-of-range port from being silently wrapped.
	host, portText, err := net.SplitHostPort(targetAddr)
	if err != nil {
		return nil, fmt.Errorf("split target addr %q: %w", targetAddr, err)
	}
	ip4 := net.ParseIP(host).To4()
	if ip4 == nil {
		return nil, fmt.Errorf("target host %q is not an IPv4 address", host)
	}
	port, err := strconv.ParseUint(portText, 10, 16)
	if err != nil {
		return nil, fmt.Errorf("parse target port %q: %w", portText, err)
	}

	conn, err := net.DialTimeout("tcp4", socksAddr, 500*time.Millisecond)
	if err != nil {
		return nil, fmt.Errorf("dial socks %s: %w", socksAddr, err)
	}
	// abort closes the half-open connection and wraps err with the step name.
	abort := func(step string, err error) (net.Conn, error) {
		_ = conn.Close()
		return nil, fmt.Errorf("%s: %w", step, err)
	}

	if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
		return abort("write socks greeting", err)
	}
	greeting := make([]byte, 2)
	if _, err := io.ReadFull(conn, greeting); err != nil {
		return abort("read socks greeting", err)
	}
	if !bytes.Equal(greeting, []byte{5, 0}) {
		_ = conn.Close()
		return nil, fmt.Errorf("unexpected greeting: %v", greeting)
	}

	req := []byte{5, 1, 0, 1}
	req = append(req, ip4...)
	req = binary.BigEndian.AppendUint16(req, uint16(port))
	if _, err := conn.Write(req); err != nil {
		return abort("write socks connect", err)
	}
	reply := make([]byte, 10)
	if _, err := io.ReadFull(conn, reply); err != nil {
		return abort("read socks connect reply", err)
	}
	if !bytes.Equal(reply, []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) {
		_ = conn.Close()
		return nil, fmt.Errorf("unexpected reply: %v", reply)
	}
	return conn, nil
}
// TestLargeTransferOverTunnel streams 32 MiB of patterned data through the
// in-memory tunnel to the echo server and verifies every returned byte.
func TestLargeTransferOverTunnel(t *testing.T) {
	const size int64 = 32 << 20

	echoAddr := startEchoServer(t)
	rt := startTunnel(t, "client-1", "client-1")
	defer rt.stop(t)

	conn := connectViaSOCKS(t, rt.socksAddr, echoAddr)
	defer func() { _ = conn.Close() }()

	err := streamPatternAndVerifyEcho(conn, size)
	if err != nil {
		t.Fatalf("large transfer %d bytes failed: %v", size, err)
	}
}
// streamPatternAndVerifyEcho writes size bytes of the fillPattern sequence to
// conn from a background goroutine. At the same time it reads the echoed
// bytes back and compares each chunk with the expected pattern. Each read
// chunk must arrive within 10 seconds. It returns the first read error,
// mismatch, or write error; nil means every byte came back intact and in
// order.
//
// On a read-side failure it returns without waiting for the writer. The
// caller is expected to close conn, which unblocks the writer; its result
// channel is buffered, so the writer never blocks on reporting.
func streamPatternAndVerifyEcho(conn net.Conn, size int64) error {
	errCh := make(chan error, 1)
	go func() {
		buf := make([]byte, 32*1024)
		var written int64
		for written < size {
			n := len(buf)
			if remaining := size - written; remaining < int64(n) {
				n = int(remaining)
			}
			fillPattern(buf[:n], written)
			if _, err := conn.Write(buf[:n]); err != nil {
				errCh <- fmt.Errorf("write at %d: %w", written, err)
				return
			}
			written += int64(n)
		}
		errCh <- nil
	}()
	buf := make([]byte, 32*1024)
	want := make([]byte, len(buf))
	var read int64
	for read < size {
		n := len(buf)
		if remaining := size - read; remaining < int64(n) {
			n = int(remaining)
		}
		if err := conn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
			return err
		}
		if _, err := io.ReadFull(conn, buf[:n]); err != nil {
			return fmt.Errorf("read at %d: %w", read, err)
		}
		fillPattern(want[:n], read)
		if !bytes.Equal(buf[:n], want[:n]) {
			return fmt.Errorf("payload mismatch at offset %d", read)
		}
		read += int64(n)
	}
	if err := <-errCh; err != nil {
		return err
	}
	return nil
}

// fillPattern fills buf with the test pattern for the stream bytes that start
// at absolute offset. Each byte depends only on its absolute stream position,
// so writer and reader can use different chunk sizes and still agree. The old
// formula (offset + i*31) only matched when both sides chunked identically.
// Because 31 is odd, the pattern has period 256, so any shift of the stream
// that is not a multiple of 256 bytes (e.g. a dropped or duplicated byte) is
// detected.
func fillPattern(buf []byte, offset int64) {
	for i := range buf {
		buf[i] = byte(((offset+int64(i))*31 + 7) & 0xff)
	}
}