diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 21a1ce5..a057256 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "net" + "os" + "strconv" "sync" "testing" "time" @@ -37,6 +39,58 @@ type memoryRoom struct { streams map[*memoryStream]struct{} } +func (r *memoryRoom) connectedCount() int { + r.mu.Lock() + defer r.mu.Unlock() + + count := 0 + for stream := range r.streams { + if stream.isConnected() { + count++ + } + } + return count +} + +func (r *memoryRoom) waitConnected(t *testing.T, want int) { + t.Helper() + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + if r.connectedCount() >= want { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("memory room connected streams = %d, want at least %d", r.connectedCount(), want) +} + +func (r *memoryRoom) triggerReconnect() { + r.mu.Lock() + streams := make([]*memoryStream, 0, len(r.streams)) + for stream := range r.streams { + streams = append(streams, stream) + } + r.mu.Unlock() + + for _, stream := range streams { + stream.triggerReconnect() + } +} + +func (r *memoryRoom) triggerEnded(reason string) { + r.mu.Lock() + streams := make([]*memoryStream, 0, len(r.streams)) + for stream := range r.streams { + streams = append(streams, stream) + } + r.mu.Unlock() + + for _, stream := range streams { + stream.triggerEnded(reason) + } +} + type memoryStream struct { room *memoryRoom onData func([]byte) @@ -44,12 +98,23 @@ type memoryStream struct { mu sync.Mutex connected bool closed bool + reconnect func() + ended func(string) + pending [][]byte } func (s *memoryStream) Connect(context.Context) error { s.mu.Lock() s.connected = true + pending := s.pending + s.pending = nil + onData := s.onData s.mu.Unlock() + for _, payload := range pending { + if onData != nil { + onData(payload) + } + } return nil } @@ -79,7 +144,12 @@ func (s *memoryStream) Send(data []byte) error { func (s *memoryStream) deliver(data []byte) { s.mu.Lock() - ready := s.connected && !s.closed && s.onData != nil + if !s.connected && !s.closed { + s.pending = append(s.pending, append([]byte(nil), data...)) + s.mu.Unlock() + return + } + ready := !s.closed && s.onData != nil onData := s.onData s.mu.Unlock() if ready { @@ -95,19 +165,51 @@ func (s *memoryStream) Close() error { return nil } -func (s *memoryStream) SetReconnectCallback(func()) {} +func (s *memoryStream) SetReconnectCallback(cb func()) { + s.mu.Lock() + s.reconnect = cb + s.mu.Unlock() +} func (s *memoryStream) SetShouldReconnect(func() bool) {} -func (s *memoryStream) SetEndedCallback(func(string)) {} +func (s *memoryStream) SetEndedCallback(cb func(string)) { + s.mu.Lock() + s.ended = cb + s.mu.Unlock() +} func (s *memoryStream) WatchConnection(ctx context.Context) { <-ctx.Done() } func (s *memoryStream) CanSend() bool { + return s.isConnected() +} + +func (s *memoryStream) isConnected() bool { s.mu.Lock() defer s.mu.Unlock() return s.connected && !s.closed } -func registerMemoryCarrier(t *testing.T) string { +func (s *memoryStream) triggerReconnect() { + s.mu.Lock() + reconnect := s.reconnect + ready := s.connected && !s.closed && reconnect != nil + s.mu.Unlock() + if ready { + reconnect() + } +} + +func (s *memoryStream) triggerEnded(reason string) { + s.mu.Lock() + ended := s.ended + ready := s.connected && !s.closed && ended != nil + s.mu.Unlock() + if ready { + ended(reason) + } +} + +func registerMemoryCarrier(t *testing.T) (string, *memoryRoom) { t.Helper() session.RegisterDefaults() @@ -120,7 +222,7 @@ func registerMemoryCarrier(t *testing.T) string { room.mu.Unlock() return &memorySession{stream: stream}, nil }) - return name + return name, room } func startEchoServer(t *testing.T) string { @@ -170,9 +272,18 @@ func waitForReady(t *testing.T, ready <-chan struct{}) { } } -func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { - carrierName := registerMemoryCarrier(t) - echoAddr := startEchoServer(t) +type tunnelRuntime struct { + socksAddr string + room *memoryRoom + cancel context.CancelFunc + serverErr chan error + clientErr chan error +} + +func startTunnel(t *testing.T, serverClientID, clientClientID string) *tunnelRuntime { + t.Helper() + + carrierName, room := registerMemoryCarrier(t) socksAddr := freeLocalAddr(t) ctx, cancel := context.WithCancel(context.Background()) @@ -187,7 +298,7 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { carrierName, "room", testKeyHex, - "client-1", + serverClientID, "127.0.0.1:53", "", 0, @@ -205,6 +316,7 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { 0, ) }() + room.waitConnected(t, 1) ready := make(chan struct{}) clientErr := make(chan error, 1) @@ -216,7 +328,7 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { carrierName, "room", testKeyHex, - "client-1", + clientClientID, socksAddr, "127.0.0.1:53", "", @@ -238,6 +350,93 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { }() waitForReady(t, ready) + return &tunnelRuntime{ + socksAddr: socksAddr, + room: room, + cancel: cancel, + serverErr: serverErr, + clientErr: clientErr, + } +} + +func (r *tunnelRuntime) stop(t *testing.T) { + t.Helper() + r.cancel() + r.waitStopped(t) +} + +func (r *tunnelRuntime) waitStopped(t *testing.T) { + t.Helper() + for name, ch := range map[string]<-chan error{"client": r.clientErr, "server": r.serverErr} { + select { + case err := <-ch: + if err != nil { + t.Fatalf("%s returned error: %v", name, err) + } + case <-time.After(3 * time.Second): + t.Fatalf("%s did not stop", name) + } + } +} + +func connectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn { + t.Helper() + + conn, err := net.DialTimeout("tcp4", socksAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial socks: %v", err) + } + + if _, err := conn.Write([]byte{5, 1, 0}); err != nil { + _ = conn.Close() + t.Fatalf("write socks greeting: %v", err) + } + greeting := make([]byte, 2) + if _, err := io.ReadFull(conn, greeting); err != nil { + _ = conn.Close() + t.Fatalf("read socks greeting: %v", err) + } + if !bytes.Equal(greeting, []byte{5, 0}) { + _ = conn.Close() + t.Fatalf("socks greeting = %v, want [5 0]", greeting) + } + + host, portText, err := net.SplitHostPort(targetAddr) + if err != nil { + _ = conn.Close() + t.Fatalf("split target addr: %v", err) + } + port, err := strconv.Atoi(portText) + if err != nil { + _ = conn.Close() + t.Fatalf("parse target port: %v", err) + } + req := []byte{5, 1, 0, 1} + req = append(req, net.ParseIP(host).To4()...) + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], uint16(port)) + req = append(req, portBuf[:]...) + if _, err := conn.Write(req); err != nil { + _ = conn.Close() + t.Fatalf("write socks connect: %v", err) + } + + reply := make([]byte, 10) + if _, err := io.ReadFull(conn, reply); err != nil { + _ = conn.Close() + t.Fatalf("read socks connect reply: %v", err) + } + if !bytes.Equal(reply, []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) { + _ = conn.Close() + t.Fatalf("socks reply = %v, want success", reply) + } + + return conn +} + +func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) []byte { + t.Helper() + conn, err := net.DialTimeout("tcp4", socksAddr, 2*time.Second) if err != nil { t.Fatalf("dial socks: %v", err) @@ -251,17 +450,14 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { if _, err := io.ReadFull(conn, greeting); err != nil { t.Fatalf("read socks greeting: %v", err) } - if !bytes.Equal(greeting, []byte{5, 0}) { - t.Fatalf("socks greeting = %v, want [5 0]", greeting) - } - host, portText, err := net.SplitHostPort(echoAddr) + host, portText, err := net.SplitHostPort(targetAddr) if err != nil { - t.Fatalf("split echo addr: %v", err) + t.Fatalf("split target addr: %v", err) } - var port int - if _, err := fmt.Sscanf(portText, "%d", &port); err != nil { - t.Fatalf("parse echo port: %v", err) + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatalf("parse target port: %v", err) } req := []byte{5, 1, 0, 1} req = append(req, net.ParseIP(host).To4()...) @@ -274,11 +470,18 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { reply := make([]byte, 10) if _, err := io.ReadFull(conn, reply); err != nil { - t.Fatalf("read socks connect reply: %v", err) - } - if !bytes.Equal(reply, []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) { - t.Fatalf("socks reply = %v, want success", reply) + t.Fatalf("read socks failure reply: %v", err) } + return reply +} + +func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { + echoAddr := startEchoServer(t) + rt := startTunnel(t, "client-1", "client-1") + defer rt.stop(t) + + conn := connectViaSOCKS(t, rt.socksAddr, echoAddr) + defer func() { _ = conn.Close() }() payload := []byte("olcrtc-e2e-payload\n") if _, err := conn.Write(payload); err != nil { @@ -294,16 +497,186 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { if !bytes.Equal(line, payload) { t.Fatalf("echo = %q, want %q", line, payload) } +} - cancel() - for name, ch := range map[string]<-chan error{"client": clientErr, "server": serverErr} { - select { - case err := <-ch: - if err != nil { - t.Fatalf("%s returned error: %v", name, err) - } - case <-time.After(3 * time.Second): - t.Fatalf("%s did not stop", name) +func TestWrongClientIDIsRejected(t *testing.T) { + echoAddr := startEchoServer(t) + rt := startTunnel(t, "server-client", "wrong-client") + defer rt.stop(t) + + reply := connectViaSOCKSExpectFailure(t, rt.socksAddr, echoAddr) + if !bytes.Equal(reply, []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) { + t.Fatalf("wrong client-id reply = %v, want host unreachable", reply) + } +} + +func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { + echoAddr := startEchoServer(t) + rt := startTunnel(t, "client-1", "client-1") + defer rt.stop(t) + + for i := range 5 { + rt.room.triggerReconnect() + conn := eventuallyConnectViaSOCKS(t, rt.socksAddr, echoAddr) + payload := []byte(fmt.Sprintf("after-reconnect-%d\n", i)) + if _, err := conn.Write(payload); err != nil { + _ = conn.Close() + t.Fatalf("write after reconnect %d: %v", i, err) + } + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + _ = conn.Close() + t.Fatalf("set deadline after reconnect %d: %v", i, err) + } + line, err := bufio.NewReader(conn).ReadBytes('\n') + _ = conn.Close() + if err != nil { + t.Fatalf("read after reconnect %d: %v", i, err) + } + if !bytes.Equal(line, payload) { + t.Fatalf("echo after reconnect %d = %q, want %q", i, line, payload) } } } + +func TestEndedCallbackStopsClientAndServer(t *testing.T) { + rt := startTunnel(t, "client-1", "client-1") + rt.room.triggerEnded("conference ended") + rt.waitStopped(t) +} + +func eventuallyConnectViaSOCKS(t *testing.T, socksAddr, targetAddr string) net.Conn { + t.Helper() + + deadline := time.Now().Add(3 * time.Second) + var lastErr error + for time.Now().Before(deadline) { + conn, err := tryConnectViaSOCKS(socksAddr, targetAddr) + if err == nil { + return conn + } + lastErr = err + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("connect after reconnect failed: %v", lastErr) + return nil +} + +func tryConnectViaSOCKS(socksAddr, targetAddr string) (net.Conn, error) { + conn, err := net.DialTimeout("tcp4", socksAddr, 500*time.Millisecond) + if err != nil { + return nil, err + } + if _, err := conn.Write([]byte{5, 1, 0}); err != nil { + _ = conn.Close() + return nil, err + } + greeting := make([]byte, 2) + if _, err := io.ReadFull(conn, greeting); err != nil { + _ = conn.Close() + return nil, err + } + if !bytes.Equal(greeting, []byte{5, 0}) { + _ = conn.Close() + return nil, fmt.Errorf("unexpected greeting: %v", greeting) + } + + host, portText, err := net.SplitHostPort(targetAddr) + if err != nil { + _ = conn.Close() + return nil, err + } + port, err := strconv.Atoi(portText) + if err != nil { + _ = conn.Close() + return nil, err + } + req := []byte{5, 1, 0, 1} + req = append(req, net.ParseIP(host).To4()...) + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], uint16(port)) + req = append(req, portBuf[:]...) + if _, err := conn.Write(req); err != nil { + _ = conn.Close() + return nil, err + } + reply := make([]byte, 10) + if _, err := io.ReadFull(conn, reply); err != nil { + _ = conn.Close() + return nil, err + } + if !bytes.Equal(reply, []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) { + _ = conn.Close() + return nil, fmt.Errorf("unexpected reply: %v", reply) + } + return conn, nil +} + +func TestLargeTransferOverTunnel(t *testing.T) { + echoAddr := startEchoServer(t) + rt := startTunnel(t, "client-1", "client-1") + defer rt.stop(t) + + size := int64(32 << 20) + if os.Getenv("OLCRTC_E2E_10GB") == "1" { + size = 10 * 1024 * 1024 * 1024 + } + + conn := connectViaSOCKS(t, rt.socksAddr, echoAddr) + defer func() { _ = conn.Close() }() + + if err := streamPatternAndVerifyEcho(conn, size); err != nil { + t.Fatalf("large transfer %d bytes failed: %v", size, err) + } +} + +func streamPatternAndVerifyEcho(conn net.Conn, size int64) error { + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 32*1024) + var written int64 + for written < size { + n := len(buf) + if remaining := size - written; remaining < int64(n) { + n = int(remaining) + } + fillPattern(buf[:n], written) + if _, err := conn.Write(buf[:n]); err != nil { + errCh <- fmt.Errorf("write at %d: %w", written, err) + return + } + written += int64(n) + } + errCh <- nil + }() + + buf := make([]byte, 32*1024) + want := make([]byte, len(buf)) + var read int64 + for read < size { + n := len(buf) + if remaining := size - read; remaining < int64(n) { + n = int(remaining) + } + if err := conn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil { + return err + } + if _, err := io.ReadFull(conn, buf[:n]); err != nil { + return fmt.Errorf("read at %d: %w", read, err) + } + fillPattern(want[:n], read) + if !bytes.Equal(buf[:n], want[:n]) { + return fmt.Errorf("payload mismatch at offset %d", read) + } + read += int64(n) + } + if err := <-errCh; err != nil { + return err + } + return nil +} + +func fillPattern(buf []byte, offset int64) { + for i := range buf { + buf[i] = byte((offset + int64(i)*31 + 7) & 0xff) + } +}