mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
refactor: extract length-prefix framing into shared package
handshake and control duplicated the same 4-byte BE length + body framing with independent ErrFrameTooLarge constants. Centralize in internal/framing and have both callers delegate. ErrFrameTooLarge is re-exported so existing errors.Is checks keep working. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -12,13 +12,14 @@ package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/framing"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -53,7 +54,7 @@ var (
|
||||
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
|
||||
ErrUnexpectedMessage = errors.New("unexpected control message")
|
||||
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
|
||||
ErrFrameTooLarge = errors.New("control frame too large")
|
||||
ErrFrameTooLarge = framing.ErrFrameTooLarge
|
||||
)
|
||||
|
||||
// Message is one control-stream frame.
|
||||
@@ -308,36 +309,9 @@ func parseMessage(raw []byte) (Message, error) {
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg Message) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal control message: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write control hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write control body: %w", err)
|
||||
}
|
||||
return nil
|
||||
return framing.WriteJSON(w, msg, MaxMessageSize)
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read control hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read control body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
return framing.ReadBytes(r, MaxMessageSize)
|
||||
}
|
||||
|
||||
60
internal/framing/framing.go
Normal file
60
internal/framing/framing.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package framing implements the length-prefixed JSON message framing used by
|
||||
// the olcrtc control and handshake protocols.
|
||||
//
|
||||
// Wire format: 4-byte big-endian length followed by that many bytes of body.
|
||||
// Body interpretation (JSON, protobuf, etc.) is up to the caller; this package
|
||||
// only deals with byte-level framing.
|
||||
package framing
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrFrameTooLarge is returned when a frame exceeds the configured max size.
|
||||
var ErrFrameTooLarge = errors.New("frame too large")
|
||||
|
||||
// WriteJSON marshals msg as JSON and writes it framed.
|
||||
func WriteJSON(w io.Writer, msg any, maxSize int) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
return WriteBytes(w, body, maxSize)
|
||||
}
|
||||
|
||||
// WriteBytes writes body as a single length-prefixed frame.
|
||||
func WriteBytes(w io.Writer, body []byte, maxSize int) error {
|
||||
if maxSize > 0 && len(body) > maxSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), maxSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // size bounded by maxSize check
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadBytes reads one length-prefixed frame from r.
|
||||
func ReadBytes(r io.Reader, maxSize int) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if maxSize > 0 && n > uint32(maxSize) { //nolint:gosec // maxSize is non-negative
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, maxSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
77
internal/framing/framing_test.go
Normal file
77
internal/framing/framing_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package framing_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/framing"
|
||||
)
|
||||
|
||||
func TestRoundTripJSON(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
type msg struct {
|
||||
Type string `json:"type"`
|
||||
N int `json:"n"`
|
||||
}
|
||||
in := msg{Type: "ping", N: 7}
|
||||
if err := framing.WriteJSON(&buf, in, 1024); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
body, err := framing.ReadBytes(&buf, 1024)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
want := `{"type":"ping","n":7}`
|
||||
if string(body) != want {
|
||||
t.Fatalf("body=%q want=%q", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTooLarge(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := framing.WriteBytes(&buf, []byte(strings.Repeat("x", 10)), 5)
|
||||
if !errors.Is(err, framing.ErrFrameTooLarge) {
|
||||
t.Fatalf("want ErrFrameTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTooLarge(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
// Manually craft an oversized header.
|
||||
buf.Write([]byte{0x00, 0x00, 0x10, 0x00}) // 4096
|
||||
_, err := framing.ReadBytes(&buf, 1024)
|
||||
if !errors.Is(err, framing.ErrFrameTooLarge) {
|
||||
t.Fatalf("want ErrFrameTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTruncated(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0x00, 0x00, 0x00, 0x04})
|
||||
buf.WriteByte(0x41) // only 1 of 4 body bytes
|
||||
_, err := framing.ReadBytes(&buf, 1024)
|
||||
if err == nil || errors.Is(err, framing.ErrFrameTooLarge) {
|
||||
t.Fatalf("want EOF/unexpected, got %v", err)
|
||||
}
|
||||
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Fatalf("want UnexpectedEOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroMaxAllowsAnything(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
big := bytes.Repeat([]byte{0xAA}, 100_000)
|
||||
if err := framing.WriteBytes(&buf, big, 0); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
got, err := framing.ReadBytes(&buf, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, big) {
|
||||
t.Fatalf("roundtrip mismatch")
|
||||
}
|
||||
}
|
||||
@@ -20,12 +20,13 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/framing"
|
||||
)
|
||||
|
||||
// ProtoVersion identifies the wire-format version. Bumped only on breaking
|
||||
@@ -84,7 +85,7 @@ var (
|
||||
// ErrUnexpectedMessage is returned when a peer sends the wrong message type.
|
||||
ErrUnexpectedMessage = errors.New("unexpected handshake message")
|
||||
// ErrFrameTooLarge is returned when a peer announces a frame above [MaxMessageSize].
|
||||
ErrFrameTooLarge = errors.New("handshake frame too large")
|
||||
ErrFrameTooLarge = framing.ErrFrameTooLarge
|
||||
)
|
||||
|
||||
// AuthFunc is invoked by [Server] after parsing CLIENT_HELLO.
|
||||
@@ -191,36 +192,9 @@ func Server(rw io.ReadWriter, auth AuthFunc) (Hello, string, error) {
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg any) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
return nil
|
||||
return framing.WriteJSON(w, msg, MaxMessageSize)
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
return framing.ReadBytes(r, MaxMessageSize)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user