diff --git a/internal/control/control.go b/internal/control/control.go index 7d82f04..24b2974 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -12,13 +12,14 @@ package control import ( "context" - "encoding/binary" "encoding/json" "errors" "fmt" "io" "sync" "time" + + "github.com/openlibrecommunity/olcrtc/internal/framing" ) const ( @@ -53,7 +54,7 @@ var ( // ErrUnexpectedMessage is returned for unknown or malformed control message types. ErrUnexpectedMessage = errors.New("unexpected control message") // ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize]. - ErrFrameTooLarge = errors.New("control frame too large") + ErrFrameTooLarge = framing.ErrFrameTooLarge ) // Message is one control-stream frame. @@ -308,36 +309,9 @@ func parseMessage(raw []byte) (Message, error) { } func writeFrame(w io.Writer, msg Message) error { - body, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("marshal control message: %w", err) - } - if len(body) > MaxMessageSize { - return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize) - } - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize - if _, err := w.Write(hdr[:]); err != nil { - return fmt.Errorf("write control hdr: %w", err) - } - if _, err := w.Write(body); err != nil { - return fmt.Errorf("write control body: %w", err) - } - return nil + return framing.WriteJSON(w, msg, MaxMessageSize) } func readFrame(r io.Reader) ([]byte, error) { - var hdr [4]byte - if _, err := io.ReadFull(r, hdr[:]); err != nil { - return nil, fmt.Errorf("read control hdr: %w", err) - } - n := binary.BigEndian.Uint32(hdr[:]) - if n > MaxMessageSize { - return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize) - } - buf := make([]byte, n) - if _, err := io.ReadFull(r, buf); err != nil { - return nil, fmt.Errorf("read control body: %w", err) - } - return buf, nil + return framing.ReadBytes(r, MaxMessageSize) } diff --git a/internal/framing/framing.go b/internal/framing/framing.go new file mode 100644 index 0000000..b73de24 --- /dev/null +++ b/internal/framing/framing.go @@ -0,0 +1,60 @@ +// Package framing implements the length-prefixed JSON message framing used by +// the olcrtc control and handshake protocols. +// +// Wire format: 4-byte big-endian length followed by that many bytes of body. +// Body interpretation (JSON, protobuf, etc.) is up to the caller; this package +// only deals with byte-level framing. +package framing + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" +) + +// ErrFrameTooLarge is returned when a frame exceeds the configured max size. +var ErrFrameTooLarge = errors.New("frame too large") + +// WriteJSON marshals msg as JSON and writes it framed. +func WriteJSON(w io.Writer, msg any, maxSize int) error { + body, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + return WriteBytes(w, body, maxSize) +} + +// WriteBytes writes body as a single length-prefixed frame. +func WriteBytes(w io.Writer, body []byte, maxSize int) error { + if maxSize > 0 && len(body) > maxSize { + return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), maxSize) + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // size bounded by maxSize check + if _, err := w.Write(hdr[:]); err != nil { + return fmt.Errorf("write hdr: %w", err) + } + if _, err := w.Write(body); err != nil { + return fmt.Errorf("write body: %w", err) + } + return nil +} + +// ReadBytes reads one length-prefixed frame from r. +func ReadBytes(r io.Reader, maxSize int) ([]byte, error) { + var hdr [4]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("read hdr: %w", err) + } + n := binary.BigEndian.Uint32(hdr[:]) + if maxSize > 0 && n > uint32(maxSize) { //nolint:gosec // maxSize is non-negative + return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, maxSize) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + return buf, nil +} diff --git a/internal/framing/framing_test.go b/internal/framing/framing_test.go new file mode 100644 index 0000000..1793bf7 --- /dev/null +++ b/internal/framing/framing_test.go @@ -0,0 +1,77 @@ +package framing_test + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/framing" +) + +func TestRoundTripJSON(t *testing.T) { + var buf bytes.Buffer + type msg struct { + Type string `json:"type"` + N int `json:"n"` + } + in := msg{Type: "ping", N: 7} + if err := framing.WriteJSON(&buf, in, 1024); err != nil { + t.Fatalf("write: %v", err) + } + body, err := framing.ReadBytes(&buf, 1024) + if err != nil { + t.Fatalf("read: %v", err) + } + want := `{"type":"ping","n":7}` + if string(body) != want { + t.Fatalf("body=%q want=%q", body, want) + } +} + +func TestWriteTooLarge(t *testing.T) { + var buf bytes.Buffer + err := framing.WriteBytes(&buf, []byte(strings.Repeat("x", 10)), 5) + if !errors.Is(err, framing.ErrFrameTooLarge) { + t.Fatalf("want ErrFrameTooLarge, got %v", err) + } +} + +func TestReadTooLarge(t *testing.T) { + var buf bytes.Buffer + // Manually craft an oversized header. + buf.Write([]byte{0x00, 0x00, 0x10, 0x00}) // 4096 + _, err := framing.ReadBytes(&buf, 1024) + if !errors.Is(err, framing.ErrFrameTooLarge) { + t.Fatalf("want ErrFrameTooLarge, got %v", err) + } +} + +func TestReadTruncated(t *testing.T) { + var buf bytes.Buffer + buf.Write([]byte{0x00, 0x00, 0x00, 0x04}) + buf.WriteByte(0x41) // only 1 of 4 body bytes + _, err := framing.ReadBytes(&buf, 1024) + if err == nil || errors.Is(err, framing.ErrFrameTooLarge) { + t.Fatalf("want EOF/unexpected, got %v", err) + } + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("want UnexpectedEOF, got %v", err) + } +} + +func TestZeroMaxAllowsAnything(t *testing.T) { + var buf bytes.Buffer + big := bytes.Repeat([]byte{0xAA}, 100_000) + if err := framing.WriteBytes(&buf, big, 0); err != nil { + t.Fatalf("write: %v", err) + } + got, err := framing.ReadBytes(&buf, 0) + if err != nil { + t.Fatalf("read: %v", err) + } + if !bytes.Equal(got, big) { + t.Fatalf("roundtrip mismatch") + } +} diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go index 5d34f6f..2399c76 100644 --- a/internal/handshake/handshake.go +++ b/internal/handshake/handshake.go @@ -20,12 +20,13 @@ package handshake import ( - "encoding/binary" "encoding/json" "errors" "fmt" "io" "time" + + "github.com/openlibrecommunity/olcrtc/internal/framing" ) // ProtoVersion identifies the wire-format version. Bumped only on breaking @@ -84,7 +85,7 @@ var ( // ErrUnexpectedMessage is returned when a peer sends the wrong message type. ErrUnexpectedMessage = errors.New("unexpected handshake message") // ErrFrameTooLarge is returned when a peer announces a frame above [MaxMessageSize]. - ErrFrameTooLarge = errors.New("handshake frame too large") + ErrFrameTooLarge = framing.ErrFrameTooLarge ) // AuthFunc is invoked by [Server] after parsing CLIENT_HELLO. @@ -191,36 +192,9 @@ func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) { } func writeFrame(w io.Writer, msg any) error { - body, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - if len(body) > MaxMessageSize { - return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize) - } - var hdr [4]byte - binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize - if _, err := w.Write(hdr[:]); err != nil { - return fmt.Errorf("write hdr: %w", err) - } - if _, err := w.Write(body); err != nil { - return fmt.Errorf("write body: %w", err) - } - return nil + return framing.WriteJSON(w, msg, MaxMessageSize) } func readFrame(r io.Reader) ([]byte, error) { - var hdr [4]byte - if _, err := io.ReadFull(r, hdr[:]); err != nil { - return nil, fmt.Errorf("read hdr: %w", err) - } - n := binary.BigEndian.Uint32(hdr[:]) - if n > MaxMessageSize { - return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize) - } - buf := make([]byte, n) - if _, err := io.ReadFull(r, buf); err != nil { - return nil, fmt.Errorf("read body: %w", err) - } - return buf, nil + return framing.ReadBytes(r, MaxMessageSize) }