Files
olcrtc/internal/control/control.go
2026-05-18 00:46:26 +03:00

350 lines
8.5 KiB
Go

// Package control implements the post-handshake control stream protocol.
//
// The control stream is the first smux stream after the olcrtc handshake. It
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
// tunnel path still round-trips, not merely that the provider connection is up.
//
// Wire format matches the handshake framing: a 4-byte big-endian length
// followed by a JSON message.
//
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package control
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/openlibrecommunity/olcrtc/internal/framing"
)
// Wire-format version, frame-size cap, and the fallback liveness parameters
// that Config.withDefaults substitutes for non-positive Config fields.
const (
// ProtoVersion identifies the control stream wire format. Every message this
// package builds (ping, pong, close) carries it, and parseMessage rejects
// frames announcing any other value with ErrProtocolVersion.
ProtoVersion = 1
// MaxMessageSize caps one control frame. It is the limit passed to the
// framing helpers on both the read and the write path (presumably a byte
// count — confirm against the framing package).
MaxMessageSize = 16 * 1024
// DefaultInterval is the default interval between ping probes, used when
// Config.Interval is zero or negative.
DefaultInterval = 10 * time.Second
// DefaultTimeout is the default time to wait for a pong, used when
// Config.Timeout is zero or negative.
DefaultTimeout = 5 * time.Second
// DefaultFailures is the default number of consecutive missed pongs before
// the stream is marked unhealthy, used when Config.Failures is zero or
// negative.
DefaultFailures = 3
)
// MsgType labels a control message. It is the string carried in the "type"
// key of the JSON frame (see Message.Type); parseMessage accepts only the
// TypePing, TypePong, and TypeClose values.
type MsgType string
// Message types understood by the control stream. The string values are part
// of the wire protocol.
const (
// TypePing is sent periodically to prove control-stream liveness. The probe
// loop emits one per Config.Interval tick while the stream is healthy.
TypePing MsgType = "CONTROL_PING"
// TypePong replies to a ping with the same sequence and timestamp; both
// fields are copied verbatim from the ping being answered.
TypePong MsgType = "CONTROL_PONG"
// TypeClose tells the peer this control session is intentionally closing.
// SendClose emits it; a read loop that receives it stops with
// ErrClosedByPeer.
TypeClose MsgType = "CONTROL_CLOSE"
)
// Sentinel errors surfaced by the control loop. They are wrapped with extra
// context where noted, so match them with errors.Is rather than ==.
var (
// ErrUnhealthy is returned when the stream misses too many pong replies.
// The probe loop wraps it with the number of consecutive missed pongs.
ErrUnhealthy = errors.New("control stream unhealthy")
// ErrClosedByPeer is returned when the peer gracefully closes the control
// session, i.e. when a CONTROL_CLOSE message is read.
ErrClosedByPeer = errors.New("control stream closed by peer")
// ErrProtocolVersion is returned when the peer announces an incompatible
// version. It is wrapped with both the peer's and the local version number.
ErrProtocolVersion = errors.New("incompatible control protocol version")
// ErrUnexpectedMessage is returned for unknown or malformed control message
// types. It is wrapped with the offending type string.
ErrUnexpectedMessage = errors.New("unexpected control message")
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize]. It is
// the same error value as framing.ErrFrameTooLarge, re-exported here.
ErrFrameTooLarge = framing.ErrFrameTooLarge
)
// Message is one control-stream frame. It is the JSON body that follows the
// length prefix described in the package comment.
type Message struct {
// Version is the sender's wire-format version; parseMessage rejects any
// value other than ProtoVersion.
Version int `json:"version"`
// Type selects how the receiver handles the frame (ping, pong, or close).
Type MsgType `json:"type"`
// Seq is the probe sequence number. A pong carries the Seq of the ping it
// answers. Omitted from the JSON when zero.
Seq uint64 `json:"seq,omitempty"`
// SentUnixNano is the ping sender's clock reading in Unix nanoseconds,
// echoed unchanged in the pong. The local RTT calculation does not read it;
// RTT is derived from the sender's own pending-ping table. Omitted from the
// JSON when zero.
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
}
// Health is reported when a ping round trip completes. It is the argument
// passed to Config.OnPong.
type Health struct {
// Seq is the sequence number of the ping/pong pair that completed.
Seq uint64
// RTT is the elapsed local time between recording the ping as pending and
// processing its pong.
RTT time.Duration
// LastSeen is the local time at which the pong was processed.
LastSeen time.Time
}
// Status is a point-in-time view of control-stream health maintained by
// callers that embed the control loop.
//
// NOTE(review): nothing in this file reads or writes Status. The per-field
// notes below are inferred from the field names only — confirm against the
// code that populates it.
type Status struct {
// SessionID presumably identifies the session this snapshot belongs to.
SessionID string
// LastPong is presumably the time of the most recent matched pong.
LastPong time.Time
// LastRTT is presumably the round-trip time of that pong.
LastRTT time.Duration
// MissedPongs is presumably the current count of consecutive missed pongs.
MissedPongs int
// Reconnects is presumably a running count of reconnect attempts.
Reconnects uint64
// UnhealthyEvents is presumably a running count of unhealthy transitions.
UnhealthyEvents uint64
// LastUnhealthy is presumably the time of the latest unhealthy transition.
LastUnhealthy time.Time
}
// Config controls the liveness loop. Zero or negative numeric fields are
// replaced by the package defaults before the loop starts. All callbacks are
// optional and are invoked without the loop's internal mutex held.
type Config struct {
// Interval is the time between ping probes (DefaultInterval if <= 0).
Interval time.Duration
// Timeout is how old an unanswered ping must be to count as missed
// (DefaultTimeout if <= 0). Expiry is only evaluated when the next probe
// tick fires, not by an independent timer.
Timeout time.Duration
// Failures is the number of consecutive missed pongs at which the stream is
// declared unhealthy (DefaultFailures if <= 0).
Failures int
// OnPong is called after a matching pong is received. It runs on the read
// loop's goroutine.
OnPong func(Health)
// OnMissedPong is called when one or more outstanding pongs time out. The
// argument is the running count of consecutive misses, not just the number
// that expired on this tick. It runs on the probe loop's goroutine.
OnMissedPong func(missed int)
// OnUnhealthy is called before Run returns [ErrUnhealthy], with the running
// count of consecutive misses. It runs on the probe loop's goroutine.
OnUnhealthy func(missed int)
}
// withDefaults returns a copy of cfg in which every non-positive numeric
// field has been replaced by its package default. Callbacks are carried over
// untouched; the receiver itself is not modified.
func (cfg Config) withDefaults() Config {
	resolved := cfg
	if cfg.Interval <= 0 {
		resolved.Interval = DefaultInterval
	}
	if cfg.Timeout <= 0 {
		resolved.Timeout = DefaultTimeout
	}
	if cfg.Failures <= 0 {
		resolved.Failures = DefaultFailures
	}
	return resolved
}
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
// or the configured failure threshold is reached.
//
// It starts a read loop, a probe loop, and a write loop, and returns as soon as
// the first of them stops. A helper goroutine closes rw once the derived
// context ends — which also happens when Run itself returns — so blocked reads
// and writes are released; that close is asynchronous with respect to Run's
// return. A first error that wraps context.Canceled or
// context.DeadlineExceeded is reported as nil; any other error is returned
// as-is.
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
	runCtx, stop := context.WithCancel(ctx)
	defer stop()

	st := &state{
		rw:      rw,
		cfg:     cfg.withDefaults(),
		pending: make(map[uint64]time.Time),
		now:     time.Now,
		out:     make(chan Message, 16),
	}

	// Closing rw is what unblocks a loop parked inside Read or Write.
	go func() {
		<-runCtx.Done()
		_ = rw.Close()
	}()

	loops := []func(context.Context) error{st.readLoop, st.probeLoop, st.writeLoop}
	// One slot per loop so late finishers never block on send.
	results := make(chan error, len(loops))
	for _, loop := range loops {
		go func(run func(context.Context) error) {
			results <- run(runCtx)
		}(loop)
	}

	first := <-results
	stop()

	switch {
	case errors.Is(first, context.Canceled), errors.Is(first, context.DeadlineExceeded):
		return nil
	default:
		return first
	}
}
// state is the shared bookkeeping for one Run invocation. The read, probe,
// and write loops all operate on the same instance.
type state struct {
// rw is the control stream that frames are read from and written to.
rw io.ReadWriteCloser
// cfg is the caller's configuration after defaults have been applied.
cfg Config
// now is the clock used for send times and RTTs; Run sets it to time.Now.
now func() time.Time
// out queues outbound messages for the write loop; Run creates it with a
// buffer of 16.
out chan Message
// mu guards pending, nextSeq, and failures.
mu sync.Mutex
// pending maps the sequence number of each unanswered ping to the local
// time it was recorded.
pending map[uint64]time.Time
// nextSeq is the most recently issued sequence number; it is incremented
// before use, so the first ping carries 1.
nextSeq uint64
// failures counts consecutive timed-out pings; a matching pong resets it
// to zero.
failures int
}
// readLoop reads frames from the stream, decodes them, and dispatches each
// message until something fails. A read failure is passed through
// readLoopErr so that a failure caused by cancellation is reported as such;
// decode and dispatch errors are returned unchanged.
func (s *state) readLoop(ctx context.Context) error {
	for {
		frame, err := readFrame(s.rw)
		if err != nil {
			return readLoopErr(ctx, err)
		}

		msg, err := parseMessage(frame)
		if err == nil {
			err = s.handleReadMessage(ctx, msg)
		}
		if err != nil {
			return err
		}
	}
}
// readLoopErr classifies an error seen by the read loop. If ctx has already
// ended, the original error is discarded and the context's error is returned
// wrapped with a "read loop canceled" prefix; otherwise err is returned
// untouched.
func readLoopErr(ctx context.Context, err error) error {
	cause := ctx.Err()
	if cause == nil {
		return err
	}
	return fmt.Errorf("read loop canceled: %w", cause)
}
// handleReadMessage dispatches one decoded message: a ping is answered with a
// queued pong, a pong is matched against the pending table, and a close stops
// the loop with ErrClosedByPeer. Any other type yields ErrUnexpectedMessage;
// on the read-loop path parseMessage has already filtered such types, so that
// final branch is defensive.
func (s *state) handleReadMessage(ctx context.Context, msg Message) error {
	if msg.Type == TypePing {
		return s.enqueuePong(ctx, msg)
	}
	if msg.Type == TypePong {
		s.handlePong(msg)
		return nil
	}
	if msg.Type == TypeClose {
		return ErrClosedByPeer
	}
	return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
}
// enqueuePong queues the reply to ping, echoing its sequence number and
// timestamp. A queueing failure is routed through readLoopErr so it is
// reported the same way as any other read-loop error.
func (s *state) enqueuePong(ctx context.Context, ping Message) error {
	pong := Message{
		Version:      ProtoVersion,
		Type:         TypePong,
		Seq:          ping.Seq,
		SentUnixNano: ping.SentUnixNano,
	}
	if err := s.enqueue(ctx, pong); err != nil {
		return readLoopErr(ctx, err)
	}
	return nil
}
// probeLoop runs sendProbe once per configured interval until ctx ends or a
// probe fails. The first probe goes out one full interval after the loop
// starts.
func (s *state) probeLoop(ctx context.Context) error {
	tick := time.NewTicker(s.cfg.Interval)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			if err := s.sendProbe(ctx); err != nil {
				return err
			}
		case <-done:
			return fmt.Errorf("probe loop canceled: %w", ctx.Err())
		}
	}
}
// sendProbe performs one probe tick. It first expires every pending ping that
// is at least cfg.Timeout old, adding each to the consecutive-failure count.
// If that count has reached cfg.Failures, it reports the stream unhealthy and
// returns ErrUnhealthy without sending anything; otherwise it records a new
// pending ping and queues it for writing.
//
// Callbacks run after the mutex is released: OnMissedPong fires only when
// this tick expired at least one ping, and OnUnhealthy fires after it on the
// unhealthy path. Both receive the running failure count.
func (s *state) sendProbe(ctx context.Context) error {
	tick := s.now()

	s.mu.Lock()
	expired := 0
	for id, sentAt := range s.pending {
		if tick.Sub(sentAt) >= s.cfg.Timeout {
			delete(s.pending, id)
			expired++
		}
	}
	s.failures += expired
	streak := s.failures
	unhealthy := streak >= s.cfg.Failures

	var seq uint64
	if !unhealthy {
		// Only a healthy stream gets a fresh ping registered.
		s.nextSeq++
		seq = s.nextSeq
		s.pending[seq] = tick
	}
	s.mu.Unlock()

	if expired > 0 && s.cfg.OnMissedPong != nil {
		s.cfg.OnMissedPong(streak)
	}

	if unhealthy {
		if s.cfg.OnUnhealthy != nil {
			s.cfg.OnUnhealthy(streak)
		}
		return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, streak)
	}

	return s.enqueue(ctx, Message{
		Version:      ProtoVersion,
		Type:         TypePing,
		Seq:          seq,
		SentUnixNano: tick.UnixNano(),
	})
}
// handlePong matches a pong against the pending table. A match removes the
// entry, resets the consecutive-failure count, and reports the round trip via
// OnPong (outside the mutex). A pong whose sequence is not pending — for
// example one that arrived after its ping was already expired — is ignored
// and does not reset the failure count.
func (s *state) handlePong(msg Message) {
	received := s.now()

	s.mu.Lock()
	sentAt, matched := s.pending[msg.Seq]
	if matched {
		delete(s.pending, msg.Seq)
		s.failures = 0
	}
	s.mu.Unlock()

	if matched && s.cfg.OnPong != nil {
		s.cfg.OnPong(Health{
			Seq:      msg.Seq,
			RTT:      received.Sub(sentAt),
			LastSeen: received,
		})
	}
}
// enqueue hands msg to the write loop. It blocks while the outbound queue is
// full and gives up with a wrapped context error once ctx ends.
func (s *state) enqueue(ctx context.Context, msg Message) error {
	select {
	case s.out <- msg:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("enqueue canceled: %w", ctx.Err())
	}
}
// writeLoop is the stream's only writer inside Run: it drains the outbound
// queue and writes each message as one frame. It stops when ctx ends or a
// write fails; a write failure observed after ctx has ended is reported as a
// cancellation rather than as the underlying I/O error.
func (s *state) writeLoop(ctx context.Context) error {
	for {
		select {
		case next := <-s.out:
			err := writeFrame(s.rw, next)
			if err == nil {
				continue
			}
			if cause := ctx.Err(); cause != nil {
				return fmt.Errorf("write loop canceled: %w", cause)
			}
			return err
		case <-ctx.Done():
			return fmt.Errorf("write loop canceled: %w", ctx.Err())
		}
	}
}
// parseMessage decodes one frame body and validates it. It returns a wrapped
// JSON error if raw does not decode, ErrProtocolVersion if the version is not
// ProtoVersion, and ErrUnexpectedMessage if the type is not ping, pong, or
// close. The version is checked before the type. On any error the returned
// Message is the zero value.
func parseMessage(raw []byte) (Message, error) {
	var decoded Message
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return Message{}, fmt.Errorf("parse control message: %w", err)
	}

	if decoded.Version != ProtoVersion {
		return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
			ErrProtocolVersion, decoded.Version, ProtoVersion)
	}

	switch decoded.Type {
	case TypePing, TypePong, TypeClose:
		return decoded, nil
	default:
		return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, decoded.Type)
	}
}
// SendClose sends a best-effort graceful close notification on the control stream.
func SendClose(w io.Writer) error {
return writeFrame(w, Message{Version: ProtoVersion, Type: TypeClose})
}
// writeFrame encodes msg as one frame on w via the framing package, capped at
// MaxMessageSize. Failures are wrapped with a "control:" prefix.
func writeFrame(w io.Writer, msg Message) error {
	err := framing.WriteJSON(w, msg, MaxMessageSize)
	if err == nil {
		return nil
	}
	return fmt.Errorf("control: %w", err)
}
// readFrame reads one frame body from r via the framing package, capped at
// MaxMessageSize. Failures are wrapped with a "control:" prefix and the
// returned slice is nil.
func readFrame(r io.Reader) ([]byte, error) {
	payload, err := framing.ReadBytes(r, MaxMessageSize)
	if err == nil {
		return payload, nil
	}
	return nil, fmt.Errorf("control: %w", err)
}