diff --git a/internal/config/config.go b/internal/config/config.go index e8a33dc..8df7058 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,8 @@ import ( var ( // ErrConfigNotFound is returned when a config file path is set but the file does not exist. ErrConfigNotFound = errors.New("config file not found") + // ErrConfigInvalidUTF8 is returned when a config file is not valid UTF-8. + ErrConfigInvalidUTF8 = errors.New("config file is not valid UTF-8") // ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured. ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set") // ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file. @@ -178,7 +180,7 @@ func Load(path string) (File, error) { return File{}, fmt.Errorf("read config %s: %w", path, err) } if !utf8.Valid(data) { - return File{}, fmt.Errorf("parse config %s: file is not valid UTF-8", path) + return File{}, fmt.Errorf("parse config %s: %w", path, ErrConfigInvalidUTF8) } var f File if err := yaml.Unmarshal(data, &f); err != nil { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d72a978..062788b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,7 +4,6 @@ import ( "errors" "os" "path/filepath" - "strings" "testing" "github.com/openlibrecommunity/olcrtc/internal/app/session" @@ -329,7 +328,7 @@ func TestLoadInvalidUTF8(t *testing.T) { } _, err := Load(path) - if err == nil || !strings.Contains(err.Error(), "file is not valid UTF-8") { + if !errors.Is(err, ErrConfigInvalidUTF8) { t.Fatalf("Load() error = %v, want invalid UTF-8 error", err) } } diff --git a/internal/control/control.go b/internal/control/control.go index de4f521..450f340 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -164,38 +164,52 @@ func (s *state) readLoop(ctx context.Context) error { for { raw, err := readFrame(s.rw) if err != nil { - if ctx.Err() != nil { - return fmt.Errorf("read loop canceled: %w", ctx.Err()) - } - return err + return readLoopErr(ctx, err) } msg, err := parseMessage(raw) if err != nil { return err } - switch msg.Type { - case TypePing: - if err := s.enqueue(ctx, Message{ - Version: ProtoVersion, - Type: TypePong, - Seq: msg.Seq, - SentUnixNano: msg.SentUnixNano, - }); err != nil { - if ctx.Err() != nil { - return fmt.Errorf("read loop canceled: %w", ctx.Err()) - } - return err - } - case TypePong: - s.handlePong(msg) - case TypeClose: - return ErrClosedByPeer - default: - return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) + if err := s.handleReadMessage(ctx, msg); err != nil { + return err } } } +func readLoopErr(ctx context.Context, err error) error { + if ctx.Err() != nil { + return fmt.Errorf("read loop canceled: %w", ctx.Err()) + } + return err +} + +func (s *state) handleReadMessage(ctx context.Context, msg Message) error { + switch msg.Type { + case TypePing: + return s.enqueuePong(ctx, msg) + case TypePong: + s.handlePong(msg) + return nil + case TypeClose: + return ErrClosedByPeer + default: + return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) + } +} + +func (s *state) enqueuePong(ctx context.Context, ping Message) error { + err := s.enqueue(ctx, Message{ + Version: ProtoVersion, + Type: TypePong, + Seq: ping.Seq, + SentUnixNano: ping.SentUnixNano, + }) + if err != nil { + return readLoopErr(ctx, err) + } + return nil +} + func (s *state) probeLoop(ctx context.Context) error { ticker := time.NewTicker(s.cfg.Interval) defer ticker.Stop()