diff --git a/cmd/olcrtc/main_test.go b/cmd/olcrtc/main_test.go index 18f4ddf..acb6a1d 100644 --- a/cmd/olcrtc/main_test.go +++ b/cmd/olcrtc/main_test.go @@ -13,6 +13,11 @@ import ( var errBoom = errors.New("boom") +const ( + testAuthWBStream = "wbstream" + testDNSServer = "1.1.1.1:53" +) + func writeYAML(t *testing.T, body string) string { t.Helper() dir := t.TempDir() @@ -39,12 +44,12 @@ func TestRunWithArgsRequiresConfig(t *testing.T) { func TestRunGenModeValidationErrors(t *testing.T) { session.RegisterDefaults() - if err := runWithConfig(loadedConfig{scfg: session.Config{Mode: "gen"}}); err == nil { + if err := runWithConfig(loadedConfig{scfg: session.Config{Mode: modeGen}}); err == nil { t.Fatal("runWithConfig(gen, no carrier) error = nil") } cfg := loadedConfig{scfg: session.Config{ - Mode: "gen", Auth: "wbstream", DNSServer: "1.1.1.1:53", + Mode: modeGen, Auth: testAuthWBStream, DNSServer: testDNSServer, }} if err := runWithConfig(cfg); err == nil { t.Fatal("runWithConfig(gen, amount=0) error = nil") @@ -58,7 +63,7 @@ func TestRunGenModeCallsGen(t *testing.T) { oldRunGen := runGen t.Cleanup(func() { runGen = oldRunGen }) runGen = func(scfg session.Config) error { - if scfg.Auth != "wbstream" || scfg.DNSServer != "1.1.1.1:53" || scfg.Amount != 3 { + if scfg.Auth != testAuthWBStream || scfg.DNSServer != testDNSServer || scfg.Amount != 3 { t.Fatalf("runGen scfg = %+v", scfg) } collected = append(collected, "ok") @@ -66,7 +71,7 @@ func TestRunGenModeCallsGen(t *testing.T) { } cfg := loadedConfig{scfg: session.Config{ - Mode: "gen", Auth: "wbstream", DNSServer: "1.1.1.1:53", Amount: 3, + Mode: modeGen, Auth: testAuthWBStream, DNSServer: testDNSServer, Amount: 3, }} if err := runWithConfig(cfg); err != nil { t.Fatalf("runWithConfig(gen) error = %v", err) diff --git a/internal/client/client.go b/internal/client/client.go index a793945..349e5e4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -265,7 +265,7 @@ func openControlStreamTimeout( _ = stream.SetDeadline(time.Time{}) if err != nil { _ = stream.Close() - return nil, "", err + return nil, "", fmt.Errorf("handshake client: %w", err) } return stream, sid, nil } @@ -284,6 +284,7 @@ func resolveDeviceID(deviceID, path string) (string, error) { if path == "" { return uuid.NewString(), nil } + // #nosec G304 -- persistent device ID path is explicit user configuration. data, err := os.ReadFile(path) if err == nil { id := strings.TrimSpace(string(data)) @@ -294,7 +295,7 @@ func resolveDeviceID(deviceID, path string) (string, error) { return "", fmt.Errorf("read device id %s: %w", path, err) } id := uuid.NewString() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { return "", fmt.Errorf("mkdir device id dir: %w", err) } if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index ebe6745..48976fe 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -18,6 +18,11 @@ import ( var errUnexpectedConnectRequest = errors.New("unexpected connect request") +const ( + testConnectCommand = "connect" + testConnectHost = "example.com" +) + func TestSetupCipher(t *testing.T) { keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" cipher, err := setupCipher(keyHex) @@ -384,7 +389,6 @@ func TestReadSocks5AddrReadErrors(t *testing.T) { } } -//nolint:cyclop // table-driven test naturally has many branches func TestSendConnectRequestOverSmux(t *testing.T) { a, b := net.Pipe() defer func() { @@ -417,7 +421,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) { done <- err return } - if req["cmd"] != "connect" || req["addr"] != "example.com" { //nolint:goconst,lll // test literal, repetition is intentional + if req["cmd"] != testConnectCommand || req["addr"] != testConnectHost { done <- errUnexpectedConnectRequest return } @@ -432,7 +436,7 @@ func TestSendConnectRequestOverSmux(t *testing.T) { defer func() { _ = stream.Close() }() c := &Client{deviceID: "client-1"} - if err := c.sendConnectRequest(stream, "example.com", 443); err != nil { + if err := c.sendConnectRequest(stream, testConnectHost, 443); err != nil { t.Fatalf("sendConnectRequest() error = %v", err) } if err := <-done; err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 9fcad0a..49b0f60 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,8 @@ // [Apply] to merge a parsed [File] onto an existing [session.Config]; // non-zero fields in the session config (typically populated from CLI flags) // take precedence over the YAML values. +// +//nolint:tagliatelle // YAML keys are the documented config file schema. package config import ( @@ -21,21 +23,21 @@ var ErrConfigNotFound = errors.New("config file not found") // File is the on-disk YAML schema. type File struct { - Mode string `yaml:"mode"` - Link string `yaml:"link"` - Auth Auth `yaml:"auth"` - Room Room `yaml:"room"` - Crypto Crypto `yaml:"crypto"` - Net Net `yaml:"net"` - SOCKS SOCKS `yaml:"socks"` - Engine Engine `yaml:"engine"` - Video Video `yaml:"video"` - VP8 VP8 `yaml:"vp8"` - SEI SEI `yaml:"sei"` - Gen Gen `yaml:"gen"` - Data string `yaml:"data"` - Debug bool `yaml:"debug"` - FFmpeg string `yaml:"ffmpeg"` + Mode string `yaml:"mode"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Gen Gen `yaml:"gen"` + Data string `yaml:"data"` + Debug bool `yaml:"debug"` + FFmpeg string `yaml:"ffmpeg"` } // Auth selects the auth provider. @@ -111,6 +113,7 @@ type Gen struct { // Load parses a YAML file from disk. func Load(path string) (File, error) { + // #nosec G304 -- config path is an explicit CLI/user input. data, err := os.ReadFile(path) if err != nil { if errors.Is(err, os.ErrNotExist) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6c402b2..95c4d9b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -8,6 +8,13 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" ) +const ( + testModeSrv = "srv" + testAuthProvider = "wbstream" + testRoomID = "r1" + testCryptoKey = "deadbeef" +) + func TestLoadAndApply(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "olcrtc.yaml") @@ -43,18 +50,48 @@ debug: true if err != nil { t.Fatalf("Load: %v", err) } - if f.Mode != "srv" || f.Auth.Provider != "wbstream" || f.Room.ID != "r1" || f.Crypto.Key != "deadbeef" { - t.Fatalf("unexpected file: %+v", f) - } + requireLoadedFile(t, f) got := Apply(session.Config{}, f) - if got.Mode != "srv" || got.Link != "direct" || got.Auth != "wbstream" || - got.RoomID != "r1" || got.KeyHex != "deadbeef" || - got.Transport != "datachannel" || got.DNSServer != "1.1.1.1:53" || - got.SOCKSHost != "127.0.0.1" || got.SOCKSPort != 1080 || - got.SOCKSUser != "u" || got.SOCKSPass != "p" || - got.VP8FPS != 25 || got.VP8BatchSize != 4 || got.Amount != 3 { - t.Fatalf("Apply produced wrong config: %+v", got) + requireAppliedConfig(t, got) +} + +func requireLoadedFile(t *testing.T, f File) { + t.Helper() + if f.Mode != testModeSrv { + t.Fatalf("Mode = %q, want %q", f.Mode, testModeSrv) + } + if f.Auth.Provider != testAuthProvider { + t.Fatalf("Auth.Provider = %q, want %q", f.Auth.Provider, testAuthProvider) + } + if f.Room.ID != testRoomID { + t.Fatalf("Room.ID = %q, want %q", f.Room.ID, testRoomID) + } + if f.Crypto.Key != testCryptoKey { + t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey) + } +} + +func requireAppliedConfig(t *testing.T, got session.Config) { + t.Helper() + want := session.Config{ + Mode: testModeSrv, + Link: "direct", + Auth: testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + Amount: 3, + } + if got != want { + t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) } } @@ -65,7 +102,7 @@ func TestApplyCLIWins(t *testing.T) { SOCKSPort: 9999, } f := File{ - Mode: "srv", + Mode: testModeSrv, Crypto: Crypto{Key: "from-yaml"}, SOCKS: SOCKS{Port: 1234, Host: "0.0.0.0"}, } diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 8be24c4..a3cfb0b 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -27,7 +27,18 @@ import ( "github.com/pion/webrtc/v4" ) -const testKeyHex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" +const ( + testKeyHex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + transportData = "datachannel" + transportVideo = "videochannel" + transportSEI = "seichannel" + transportVP8 = "vp8channel" + linkDirect = "direct" + testRoom = "room" + localDNSServer = "127.0.0.1:53" + videoHWNone = "none" + testClientDeviceID = "client-1" +) var ( errRealE2ENotReady = errors.New("real e2e client did not become ready") @@ -330,16 +341,16 @@ func builtInCarrierNames() []string { } func builtInTransportNames() []string { - return []string{"datachannel", "videochannel", "seichannel", "vp8channel"} + return []string{transportData, transportVideo, transportSEI, transportVP8} } func realE2ECaseExpectation(carrierName, transportName string) realE2EExpectation { switch carrierName { case "telemost": switch transportName { - case "vp8channel": + case transportVP8: return realE2EExpectPass - case "videochannel": + case transportVideo: return realE2EBestEffort default: return realE2EExpectFail @@ -347,7 +358,7 @@ func realE2ECaseExpectation(carrierName, transportName string) realE2EExpectatio case "wbstream": return realE2EExpectPass case "jazz": - if transportName == "datachannel" { + if transportName == transportData { return realE2EExpectFail } return realE2EExpectPass @@ -362,8 +373,10 @@ func realE2EExpectationLabel(expectation realE2EExpectation) string { return "SUCCESS" case realE2EBestEffort: return "BEST EFFORT" - default: + case realE2EExpectFail: return "EXPECTED FAIL" + default: + return "UNKNOWN" } } @@ -377,31 +390,31 @@ func TestRealE2ECaseExpectation(t *testing.T) { { name: "jazz datachannel is expected to fail", carrier: "jazz", - transport: "datachannel", + transport: transportData, want: realE2EExpectFail, }, { name: "jazz videochannel is expected to pass", carrier: "jazz", - transport: "videochannel", + transport: transportVideo, want: realE2EExpectPass, }, { name: "telemost datachannel is expected to fail", carrier: "telemost", - transport: "datachannel", + transport: transportData, want: realE2EExpectFail, }, { name: "telemost vp8channel is expected to pass", carrier: "telemost", - transport: "vp8channel", + transport: transportVP8, want: realE2EExpectPass, }, { name: "wbstream datachannel is expected to pass", carrier: "wbstream", - transport: "datachannel", + transport: transportData, want: realE2EExpectPass, }, } @@ -474,19 +487,19 @@ func requireRealRoom(ctx context.Context, t *testing.T, carrierName string) stri func validSessionConfig(mode, carrierName, transportName string) session.Config { return session.Config{ Mode: mode, - Link: "direct", + Link: linkDirect, Transport: transportName, Auth: carrierName, - RoomID: "room", + RoomID: testRoom, KeyHex: testKeyHex, SOCKSHost: "127.0.0.1", SOCKSPort: 1080, - DNSServer: "127.0.0.1:53", + DNSServer: localDNSServer, VideoWidth: 1080, VideoHeight: 1080, VideoFPS: 30, VideoBitrate: "1M", - VideoHW: "none", + VideoHW: videoHWNone, VideoCodec: "tile", VideoTileModule: 4, VideoTileRS: 20, @@ -504,7 +517,7 @@ func validLinkConfig(carrierName, transportName string) link.Config { return link.Config{ Transport: cfg.Transport, Carrier: cfg.Auth, - RoomURL: "room", + RoomURL: testRoom, DeviceID: "e2e-link-test", Name: "e2e-" + carrierName + "-" + transportName, DNSServer: cfg.DNSServer, @@ -583,7 +596,7 @@ type tunnelRuntime struct { stopWait time.Duration } -func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime { +func startTunnel(t *testing.T) *tunnelRuntime { t.Helper() carrierName, room := registerMemoryCarrier(t) @@ -594,12 +607,12 @@ func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime { serverErr := make(chan error, 1) go func() { serverErr <- server.Run(ctx, server.Config{ - Link: "direct", - Transport: "datachannel", + Link: linkDirect, + Transport: transportData, Carrier: carrierName, - RoomURL: "room", + RoomURL: testRoom, KeyHex: testKeyHex, - DNSServer: "127.0.0.1:53", + DNSServer: localDNSServer, }) }() room.waitConnected(t, 1) @@ -608,14 +621,14 @@ func startTunnel(t *testing.T, deviceID, _ string) *tunnelRuntime { clientErr := make(chan error, 1) go func() { clientErr <- client.RunWithReady(ctx, client.Config{ - Link: "direct", - Transport: "datachannel", + Link: linkDirect, + Transport: transportData, Carrier: carrierName, - RoomURL: "room", + RoomURL: testRoom, KeyHex: testKeyHex, - DeviceID: deviceID, + DeviceID: testClientDeviceID, LocalAddr: socksAddr, - DNSServer: "127.0.0.1:53", + DNSServer: localDNSServer, }, func() { close(ready) }) }() waitForReady(t, ready) @@ -646,17 +659,17 @@ func startRealTunnel( serverErr := make(chan error, 1) go func() { serverErr <- server.Run(runCtx, server.Config{ - Link: "direct", + Link: linkDirect, Transport: transportName, Carrier: carrierName, RoomURL: roomURL, KeyHex: testKeyHex, - DNSServer: "127.0.0.1:53", + DNSServer: localDNSServer, VideoWidth: 1080, VideoHeight: 1080, VideoFPS: 60, VideoBitrate: "5000k", - VideoHW: "none", + VideoHW: videoHWNone, VideoQRSize: 512, VideoQRRecovery: "low", VideoCodec: "qrcode", @@ -685,19 +698,19 @@ func startRealTunnel( clientErr := make(chan error, 1) go func() { clientErr <- client.RunWithReady(runCtx, client.Config{ - Link: "direct", + Link: linkDirect, Transport: transportName, Carrier: carrierName, RoomURL: roomURL, KeyHex: testKeyHex, DeviceID: clientDeviceID, LocalAddr: socksAddr, - DNSServer: "127.0.0.1:53", + DNSServer: localDNSServer, VideoWidth: 1080, VideoHeight: 1080, VideoFPS: 60, VideoBitrate: "5000k", - VideoHW: "none", + VideoHW: videoHWNone, VideoQRSize: 512, VideoQRRecovery: "low", VideoCodec: "qrcode", @@ -869,7 +882,7 @@ func TestDirectLinkCreatesAllProviderTransportCombinations(t *testing.T) { t.Run(carrierName, func(t *testing.T) { for _, transportName := range builtInTransportNames() { t.Run(transportName, func(t *testing.T) { - ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName)) + ln, err := link.New(context.Background(), linkDirect, validLinkConfig(carrierName, transportName)) if err != nil { t.Fatalf("link.New() error = %v", err) } @@ -891,9 +904,9 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { for _, carrierName := range builtInCarrierNames() { t.Run(carrierName, func(t *testing.T) { - for _, transportName := range []string{"datachannel", "seichannel"} { + for _, transportName := range []string{transportData, transportSEI} { t.Run(transportName, func(t *testing.T) { - ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName)) + ln, err := link.New(context.Background(), linkDirect, validLinkConfig(carrierName, transportName)) if err != nil { t.Fatalf("link.New() error = %v", err) } @@ -964,7 +977,7 @@ func runRealE2ECase(t *testing.T, carrierName, transportName, roomURL, echoAddr ctx, cancel := context.WithTimeout(context.Background(), *realE2ETimeout) defer cancel() - rt, err := startRealTunnel(ctx, t, carrierName, transportName, roomURL, "client-1", "client-1") + rt, err := startRealTunnel(ctx, t, carrierName, transportName, roomURL, testClientDeviceID, testClientDeviceID) if err != nil { return err } @@ -999,7 +1012,7 @@ func runRealE2ECase(t *testing.T, carrierName, transportName, roomURL, echoAddr func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { echoAddr := startEchoServer(t) - rt := startTunnel(t, "client-1", "client-1") + rt := startTunnel(t) defer rt.stop(t) conn := connectViaSOCKS(t, rt.socksAddr, echoAddr) @@ -1023,7 +1036,7 @@ func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { echoAddr := startEchoServer(t) - rt := startTunnel(t, "client-1", "client-1") + rt := startTunnel(t) defer rt.stop(t) for i := range 5 { @@ -1050,7 +1063,7 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { } func TestEndedCallbackStopsClientAndServer(t *testing.T) { - rt := startTunnel(t, "client-1", "client-1") + rt := startTunnel(t) rt.room.triggerEnded("conference ended") rt.waitStopped(t) } @@ -1144,7 +1157,7 @@ func tryConnectViaSOCKS(socksAddr, targetAddr string) (net.Conn, error) { func TestLargeTransferOverTunnel(t *testing.T) { echoAddr := startEchoServer(t) - rt := startTunnel(t, "client-1", "client-1") + rt := startTunnel(t) defer rt.stop(t) size := int64(32 << 20) diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go index 9d66f15..bec84a7 100644 --- a/internal/handshake/handshake.go +++ b/internal/handshake/handshake.go @@ -15,6 +15,8 @@ // After the exchange the control stream stays open; tunnel traffic flows over // additional smux streams opened by the client. The control stream may carry // keepalives or future control messages. +// +//nolint:tagliatelle // JSON keys are the stable wire protocol schema. package handshake import ( @@ -117,27 +119,37 @@ func Client(rw io.ReadWriter, deviceID string, claims map[string]any) (string, e } switch probe.Type { + case TypeHello: + return "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, probe.Type) case TypeWelcome: - var w Welcome - if err := json.Unmarshal(raw, &w); err != nil { - return "", fmt.Errorf("parse welcome: %w", err) - } - if w.Version != ProtoVersion { - return "", fmt.Errorf("%w: server v%d, client v%d", - ErrProtocolVersion, w.Version, ProtoVersion) - } - return w.SessionID, nil + return parseWelcome(raw) case TypeReject: - var r Reject - if err := json.Unmarshal(raw, &r); err != nil { - return "", fmt.Errorf("parse reject: %w", err) - } - return "", fmt.Errorf("%w: %s", ErrRejected, r.Reason) + return parseReject(raw) default: return "", fmt.Errorf("%w: got %q", ErrUnexpectedMessage, probe.Type) } } +func parseWelcome(raw []byte) (string, error) { + var w Welcome + if err := json.Unmarshal(raw, &w); err != nil { + return "", fmt.Errorf("parse welcome: %w", err) + } + if w.Version != ProtoVersion { + return "", fmt.Errorf("%w: server v%d, client v%d", + ErrProtocolVersion, w.Version, ProtoVersion) + } + return w.SessionID, nil +} + +func parseReject(raw []byte) (string, error) { + var r Reject + if err := json.Unmarshal(raw, &r); err != nil { + return "", fmt.Errorf("parse reject: %w", err) + } + return "", fmt.Errorf("%w: %s", ErrRejected, r.Reason) +} + // Server performs the server side of the handshake. It reads CLIENT_HELLO, // invokes auth, and writes the corresponding WELCOME or REJECT. On success it // returns the parsed Hello and the session ID produced by auth. diff --git a/internal/handshake/handshake_test.go b/internal/handshake/handshake_test.go index 790192b..e575ed1 100644 --- a/internal/handshake/handshake_test.go +++ b/internal/handshake/handshake_test.go @@ -8,6 +8,10 @@ import ( "testing" ) +const testSessionID = "sess-42" + +var errNope = errors.New("nope") + func pair(t *testing.T) (net.Conn, net.Conn) { t.Helper() a, b := net.Pipe() @@ -29,12 +33,12 @@ func TestHandshakeRoundTrip(t *testing.T) { if claims["plan"] != "pro" { t.Errorf("claims = %v", claims) } - return "sess-42", nil + return testSessionID, nil }) if err != nil { t.Errorf("Server: %v", err) } - if hello.DeviceID != "dev-1" || sid != "sess-42" { + if hello.DeviceID != "dev-1" || sid != testSessionID { t.Errorf("Server returned hello=%+v sid=%q", hello, sid) } }() @@ -43,7 +47,7 @@ func TestHandshakeRoundTrip(t *testing.T) { if err != nil { t.Fatalf("Client: %v", err) } - if sid != "sess-42" { + if sid != testSessionID { t.Fatalf("session id = %q, want sess-42", sid) } } @@ -53,7 +57,7 @@ func TestHandshakeRejected(t *testing.T) { go func() { _, _, _ = Server(sConn, func(string, map[string]any) (string, error) { - return "", errors.New("nope") + return "", errNope }) }() diff --git a/internal/server/server.go b/internal/server/server.go index bc2f557..ab00cc0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -385,10 +385,8 @@ func (s *Server) onData(data []byte) { // streams are tunnel streams and proxy traffic. func (s *Server) serve(ctx context.Context) { for { - select { - case <-ctx.Done(): + if contextDone(ctx) { return - default: } s.sessMu.RLock() @@ -411,10 +409,8 @@ func (s *Server) serve(ctx context.Context) { stream, err := sess.AcceptStream() if err != nil { - select { - case <-ctx.Done(): + if contextDone(ctx) { return - default: } logger.Debugf("AcceptStream returned %v - reinstalling session", err) s.reinstallSession(sess) @@ -429,6 +425,15 @@ func (s *Server) serve(ctx context.Context) { } } +func contextDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} + // handshakeReady reports whether the current session has completed its // handshake. The session is reset on reconnect, so this is recomputed. func (s *Server) handshakeReady() bool { @@ -568,7 +573,7 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { go func() { n, _ := io.Copy(stream, conn) if n > 0 { - bytesOut = uint64(n) //nolint:gosec // io.Copy returns non-negative int64 + bytesOut = uint64(n) } _ = stream.Close() close(done) @@ -578,7 +583,7 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { <-done bytesIn := uint64(0) if in > 0 { - bytesIn = uint64(in) //nolint:gosec // io.Copy returns non-negative int64 + bytesIn = uint64(in) } if s.onTraffic != nil { s.onTraffic(sid, addr, bytesIn, bytesOut) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 59c0846..f6034bf 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -16,6 +16,11 @@ import ( "github.com/xtaci/smux" ) +const ( + testConnectAddr = "127.0.0.1" + testConnectCmd = connectCommand +) + func TestSetupCipher(t *testing.T) { keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" cipher, err := setupCipher(keyHex) @@ -48,7 +53,7 @@ func TestSmuxConfig(t *testing.T) { func TestParseConnectRequest(t *testing.T) { buf, err := json.Marshal(ConnectRequest{ - Cmd: "connect", + Cmd: testConnectCmd, Addr: "example.com", //nolint:goconst // test literal, repetition is intentional Port: 443, }) @@ -249,7 +254,7 @@ func TestDialWithoutProxy(t *testing.T) { t.Fatalf("listener addr type = %T, want *net.TCPAddr", ln.Addr()) } s := &Server{resolver: net.DefaultResolver} - conn, err := s.dial(ConnectRequest{Addr: "127.0.0.1", Port: tcpAddr.Port}) + conn, err := s.dial(ConnectRequest{Addr: testConnectAddr, Port: tcpAddr.Port}) if err != nil { t.Fatalf("dial() error = %v", err) } @@ -258,7 +263,7 @@ func TestDialWithoutProxy(t *testing.T) { } func TestDialProxyError(t *testing.T) { - s := &Server{socksProxyAddr: "127.0.0.1", socksProxyPort: 1} + s := &Server{socksProxyAddr: testConnectAddr, socksProxyPort: 1} if _, err := s.dial(ConnectRequest{Addr: "example.com", Port: 443}); err == nil || !strings.Contains(err.Error(), "failed to dial proxy") { //nolint:lll // long test description t.Fatalf("dial() error = %v", err) } @@ -333,8 +338,8 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) { t.Fatalf("OpenStream() error = %v", err) } req, err := json.Marshal(ConnectRequest{ - Cmd: "connect", - Addr: "127.0.0.1", + Cmd: testConnectCmd, + Addr: testConnectAddr, Port: 1, // unreachable port — dispatch will fail dial and exit }) if err != nil { @@ -368,8 +373,10 @@ func TestReinstallSessionFiresOnClose(t *testing.T) { } } +//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together. func TestDispatchFiresOnTraffic(t *testing.T) { - ln, err := net.Listen("tcp4", "127.0.0.1:0") + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp4", testConnectAddr+":0") if err != nil { t.Fatalf("Listen() error = %v", err) } @@ -403,9 +410,9 @@ func TestDispatchFiresOnTraffic(t *testing.T) { defer func() { _ = clientSess.Close() }() var rec struct { - sid string - addr string - in, out uint64 + sid string + addr string + in, out uint64 } recChan := make(chan struct{}) s := &Server{ @@ -437,8 +444,8 @@ func TestDispatchFiresOnTraffic(t *testing.T) { t.Fatalf("addr type = %T", ln.Addr()) } req, err := json.Marshal(ConnectRequest{ - Cmd: "connect", - Addr: "127.0.0.1", + Cmd: testConnectCmd, + Addr: testConnectAddr, Port: tcpAddr.Port, }) if err != nil { diff --git a/mobile/mobile_test.go b/mobile/mobile_test.go index 541fba5..f22625b 100644 --- a/mobile/mobile_test.go +++ b/mobile/mobile_test.go @@ -173,8 +173,11 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) { if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz || cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" || cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 { - t.Fatalf("RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d", - cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize) + t.Fatalf( + "RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d", + cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, + cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize, + ) } onReady() <-ctx.Done() diff --git a/pkg/olcrtc/tunnel/tunnel_test.go b/pkg/olcrtc/tunnel/tunnel_test.go index c1366a0..17beeb6 100644 --- a/pkg/olcrtc/tunnel/tunnel_test.go +++ b/pkg/olcrtc/tunnel/tunnel_test.go @@ -8,6 +8,8 @@ import ( "github.com/openlibrecommunity/olcrtc/pkg/olcrtc/tunnel" ) +var errNo = errors.New("no") + func TestRun_FailsWithoutKey(t *testing.T) { tunnel.RegisterDefaults() err := tunnel.New(tunnel.Config{ @@ -22,15 +24,14 @@ func TestRun_FailsWithoutKey(t *testing.T) { } } -func TestRun_PropagatesAuthHook(t *testing.T) { +func TestRun_PropagatesAuthHook(_ *testing.T) { tunnel.RegisterDefaults() - sentinel := errors.New("no") var called bool cfg := tunnel.Config{ AuthHook: func(string, map[string]any) (string, error) { called = true - return "", sentinel + return "", errNo }, } _ = tunnel.New(cfg).Run(context.Background())