mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 15:13:40 +00:00
350 lines
8.5 KiB
Go
350 lines
8.5 KiB
Go
// Package control implements the post-handshake control stream protocol.
|
|
//
|
|
// The control stream is the first smux stream after the olcrtc handshake. It
|
|
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
|
|
// tunnel path still round-trips, not merely that the provider connection is up.
|
|
//
|
|
// Wire format matches the handshake framing: a 4-byte big-endian length
|
|
// followed by a JSON message.
|
|
//
|
|
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
|
package control
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/openlibrecommunity/olcrtc/internal/framing"
|
|
)
|
|
|
|
// Wire-format limits for the control stream.
const (
	// ProtoVersion is the control-stream wire format version. Every outgoing
	// Message carries it, and parseMessage rejects frames with any other value.
	ProtoVersion = 1
	// MaxMessageSize caps one control frame (16 Ki); it is the limit handed to
	// the framing package on both the read and the write path.
	MaxMessageSize = 16 << 10
)

// Liveness defaults. Config.withDefaults substitutes these for zero or
// negative Config fields.
const (
	// DefaultInterval is the default period between ping probes.
	DefaultInterval = 10 * time.Second
	// DefaultTimeout is the default time a ping may wait for its pong before
	// it counts as missed.
	DefaultTimeout = 5 * time.Second
	// DefaultFailures is the default number of consecutive missed pongs after
	// which the stream is declared unhealthy.
	DefaultFailures = 3
)
|
|
|
|
// MsgType is the wire name of a control message, carried in Message.Type.
type MsgType string

// Control message kinds. The string values are part of the wire protocol and
// must not change.
const (
	// TypePing is sent every Config.Interval by Run's prober to prove the
	// control stream still round-trips.
	TypePing = MsgType("CONTROL_PING")
	// TypePong answers a ping, echoing its Seq and SentUnixNano unchanged.
	TypePong = MsgType("CONTROL_PONG")
	// TypeClose announces that the sender is intentionally ending the control
	// session; on receipt, Run returns ErrClosedByPeer.
	TypeClose = MsgType("CONTROL_CLOSE")
)
|
|
|
|
var (
|
|
// ErrUnhealthy is returned when the stream misses too many pong replies.
|
|
ErrUnhealthy = errors.New("control stream unhealthy")
|
|
// ErrClosedByPeer is returned when the peer gracefully closes the control session.
|
|
ErrClosedByPeer = errors.New("control stream closed by peer")
|
|
// ErrProtocolVersion is returned when the peer announces an incompatible version.
|
|
ErrProtocolVersion = errors.New("incompatible control protocol version")
|
|
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
|
|
ErrUnexpectedMessage = errors.New("unexpected control message")
|
|
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
|
|
ErrFrameTooLarge = framing.ErrFrameTooLarge
|
|
)
|
|
|
|
// Message is one control-stream frame, encoded as JSON inside a
// length-prefixed frame. encoding/json emits keys in field order, so the field
// order below is the key order on the wire.
type Message struct {
	// Version must equal ProtoVersion; parseMessage rejects any other value.
	Version int `json:"version"`
	// Type selects the message kind: CONTROL_PING, CONTROL_PONG or
	// CONTROL_CLOSE.
	Type MsgType `json:"type"`
	// Seq identifies a ping; a pong echoes the Seq of the ping it answers.
	// Omitted from the JSON when zero (SendClose leaves it unset).
	Seq uint64 `json:"seq,omitempty"`
	// SentUnixNano is the pinging side's clock, in Unix nanoseconds, when the
	// ping was issued; a pong echoes it unchanged. Omitted when zero.
	SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
}
|
|
|
|
// Health is reported when a ping round trip completes. It is delivered to
// Config.OnPong.
type Health struct {
	// Seq is the sequence number of the ping that was answered.
	Seq uint64
	// RTT is the time between recording the ping as pending and receiving its
	// matching pong, measured on the local clock.
	RTT time.Duration
	// LastSeen is the local time at which the pong was received.
	LastSeen time.Time
}
|
|
|
|
// Status is a point-in-time view of control-stream health maintained by
// callers that embed the control loop. Nothing in this package section fills
// it in; the field notes below are inferred from names — confirm against the
// callers that populate it.
type Status struct {
	// SessionID presumably identifies the session being monitored.
	SessionID string
	// LastPong presumably mirrors Health.LastSeen of the latest OnPong call.
	LastPong time.Time
	// LastRTT presumably mirrors Health.RTT of the latest OnPong call.
	LastRTT time.Duration
	// MissedPongs presumably mirrors the count passed to OnMissedPong.
	MissedPongs int
	// Reconnects presumably counts control-stream re-establishments.
	Reconnects uint64
	// UnhealthyEvents presumably counts OnUnhealthy invocations.
	UnhealthyEvents uint64
	// LastUnhealthy presumably records when the stream last went unhealthy.
	LastUnhealthy time.Time
}
|
|
|
|
// Config controls the liveness loop. Zero or negative Interval, Timeout and
// Failures are replaced by DefaultInterval, DefaultTimeout and DefaultFailures.
//
// Callbacks run on Run's internal goroutines: OnPong on the reader, OnMissedPong
// and OnUnhealthy on the prober. OnPong may therefore run concurrently with the
// other two.
type Config struct {
	// Interval is the period between ping probes.
	Interval time.Duration
	// Timeout is how long a ping may stay unanswered before it counts as
	// missed. Expiry is only checked when the next probe fires, so a miss can
	// be detected up to one Interval after Timeout has elapsed.
	Timeout time.Duration
	// Failures is the number of consecutive missed pongs that makes Run
	// return ErrUnhealthy. A matching pong resets the count to zero.
	Failures int

	// OnPong is called after a matching pong is received.
	OnPong func(Health)
	// OnMissedPong is called when one or more outstanding pongs time out.
	// missed is the running count of consecutive misses, not just the number
	// that expired in this check.
	OnMissedPong func(missed int)
	// OnUnhealthy is called before Run returns [ErrUnhealthy].
	OnUnhealthy func(missed int)
}
|
|
|
|
func (cfg Config) withDefaults() Config {
|
|
if cfg.Interval <= 0 {
|
|
cfg.Interval = DefaultInterval
|
|
}
|
|
if cfg.Timeout <= 0 {
|
|
cfg.Timeout = DefaultTimeout
|
|
}
|
|
if cfg.Failures <= 0 {
|
|
cfg.Failures = DefaultFailures
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
|
|
// or the configured failure threshold is reached.
|
|
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
|
|
cfg = cfg.withDefaults()
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
state := &state{
|
|
rw: rw,
|
|
cfg: cfg,
|
|
pending: make(map[uint64]time.Time),
|
|
now: time.Now,
|
|
out: make(chan Message, 16),
|
|
}
|
|
|
|
errCh := make(chan error, 3)
|
|
go func() {
|
|
<-ctx.Done()
|
|
_ = rw.Close()
|
|
}()
|
|
go func() { errCh <- state.readLoop(ctx) }()
|
|
go func() { errCh <- state.probeLoop(ctx) }()
|
|
go func() { errCh <- state.writeLoop(ctx) }()
|
|
|
|
err := <-errCh
|
|
cancel()
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// state is the shared bookkeeping for one Run invocation. The read, probe and
// write loops all use it concurrently.
type state struct {
	// rw is the control stream: readLoop reads frames from it and writeLoop
	// writes frames to it.
	rw io.ReadWriteCloser
	// cfg is the Config after withDefaults has been applied by Run.
	cfg Config
	// now returns the current time. Run sets it to time.Now (presumably a seam
	// so tests can inject a fake clock — confirm).
	now func() time.Time

	// out queues outgoing messages: enqueue sends, writeLoop receives.
	out chan Message

	// mu guards pending, nextSeq and failures.
	mu sync.Mutex
	// pending maps each unanswered ping's sequence number to the time it was
	// recorded, just before it was queued.
	pending map[uint64]time.Time
	// nextSeq is the sequence number of the most recently issued ping.
	nextSeq uint64
	// failures counts pings that timed out since the last matching pong.
	failures int
}
|
|
|
|
func (s *state) readLoop(ctx context.Context) error {
|
|
for {
|
|
raw, err := readFrame(s.rw)
|
|
if err != nil {
|
|
return readLoopErr(ctx, err)
|
|
}
|
|
msg, err := parseMessage(raw)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := s.handleReadMessage(ctx, msg); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// readLoopErr classifies an error raised on the read path. If ctx has already
// been canceled, the failure is a side effect of shutdown (for example, rw
// closed under a pending Read), so the context error is reported instead.
// Otherwise err is returned unchanged.
func readLoopErr(ctx context.Context, err error) error {
	ctxErr := ctx.Err()
	if ctxErr == nil {
		return err
	}
	return fmt.Errorf("read loop canceled: %w", ctxErr)
}
|
|
|
|
func (s *state) handleReadMessage(ctx context.Context, msg Message) error {
|
|
switch msg.Type {
|
|
case TypePing:
|
|
return s.enqueuePong(ctx, msg)
|
|
case TypePong:
|
|
s.handlePong(msg)
|
|
return nil
|
|
case TypeClose:
|
|
return ErrClosedByPeer
|
|
default:
|
|
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
|
}
|
|
}
|
|
|
|
func (s *state) enqueuePong(ctx context.Context, ping Message) error {
|
|
err := s.enqueue(ctx, Message{
|
|
Version: ProtoVersion,
|
|
Type: TypePong,
|
|
Seq: ping.Seq,
|
|
SentUnixNano: ping.SentUnixNano,
|
|
})
|
|
if err != nil {
|
|
return readLoopErr(ctx, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *state) probeLoop(ctx context.Context) error {
|
|
ticker := time.NewTicker(s.cfg.Interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("probe loop canceled: %w", ctx.Err())
|
|
case <-ticker.C:
|
|
if err := s.sendProbe(ctx); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *state) sendProbe(ctx context.Context) error {
|
|
now := s.now()
|
|
|
|
s.mu.Lock()
|
|
missedNow := 0
|
|
for seq, sent := range s.pending {
|
|
if now.Sub(sent) < s.cfg.Timeout {
|
|
continue
|
|
}
|
|
delete(s.pending, seq)
|
|
s.failures++
|
|
missedNow++
|
|
}
|
|
missed := s.failures
|
|
if s.failures >= s.cfg.Failures {
|
|
s.mu.Unlock()
|
|
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
|
s.cfg.OnMissedPong(missed)
|
|
}
|
|
if s.cfg.OnUnhealthy != nil {
|
|
s.cfg.OnUnhealthy(missed)
|
|
}
|
|
return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
|
|
}
|
|
|
|
s.nextSeq++
|
|
seq := s.nextSeq
|
|
s.pending[seq] = now
|
|
s.mu.Unlock()
|
|
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
|
s.cfg.OnMissedPong(missed)
|
|
}
|
|
|
|
return s.enqueue(ctx, Message{
|
|
Version: ProtoVersion,
|
|
Type: TypePing,
|
|
Seq: seq,
|
|
SentUnixNano: now.UnixNano(),
|
|
})
|
|
}
|
|
|
|
func (s *state) handlePong(msg Message) {
|
|
now := s.now()
|
|
|
|
s.mu.Lock()
|
|
sent, ok := s.pending[msg.Seq]
|
|
if ok {
|
|
delete(s.pending, msg.Seq)
|
|
s.failures = 0
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
if !ok || s.cfg.OnPong == nil {
|
|
return
|
|
}
|
|
s.cfg.OnPong(Health{
|
|
Seq: msg.Seq,
|
|
RTT: now.Sub(sent),
|
|
LastSeen: now,
|
|
})
|
|
}
|
|
|
|
func (s *state) enqueue(ctx context.Context, msg Message) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("enqueue canceled: %w", ctx.Err())
|
|
case s.out <- msg:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *state) writeLoop(ctx context.Context) error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("write loop canceled: %w", ctx.Err())
|
|
case msg := <-s.out:
|
|
if err := writeFrame(s.rw, msg); err != nil {
|
|
if ctx.Err() != nil {
|
|
return fmt.Errorf("write loop canceled: %w", ctx.Err())
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseMessage(raw []byte) (Message, error) {
|
|
var msg Message
|
|
if err := json.Unmarshal(raw, &msg); err != nil {
|
|
return Message{}, fmt.Errorf("parse control message: %w", err)
|
|
}
|
|
if msg.Version != ProtoVersion {
|
|
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
|
|
ErrProtocolVersion, msg.Version, ProtoVersion)
|
|
}
|
|
if msg.Type != TypePing && msg.Type != TypePong && msg.Type != TypeClose {
|
|
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
// SendClose sends a best-effort graceful close notification on the control stream.
|
|
func SendClose(w io.Writer) error {
|
|
return writeFrame(w, Message{Version: ProtoVersion, Type: TypeClose})
|
|
}
|
|
|
|
func writeFrame(w io.Writer, msg Message) error {
|
|
if err := framing.WriteJSON(w, msg, MaxMessageSize); err != nil {
|
|
return fmt.Errorf("control: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func readFrame(r io.Reader) ([]byte, error) {
|
|
body, err := framing.ReadBytes(r, MaxMessageSize)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("control: %w", err)
|
|
}
|
|
return body, nil
|
|
}
|