From 31c0736562c9d5902a2170d90a1fc0443e75d016 Mon Sep 17 00:00:00 2001 From: keven1024 Date: Sat, 27 Dec 2025 11:04:21 +0800 Subject: [PATCH] feat(backend): add password hashing error handling and implement HTTP utility tests --- backend/internal/utils/password.go | 6 +- .../utils/http_result_test.go | 6 +- .../{internal => test}/utils/password_test.go | 31 ++--- pkg/utils/env.go | 106 +++++++++++++----- pkg/utils/redis.go | 13 ++- 5 files changed, 113 insertions(+), 49 deletions(-) rename backend/{internal => test}/utils/http_result_test.go (91%) rename backend/{internal => test}/utils/password_test.go (72%) diff --git a/backend/internal/utils/password.go b/backend/internal/utils/password.go index 9b0b3ec..b1bd277 100644 --- a/backend/internal/utils/password.go +++ b/backend/internal/utils/password.go @@ -8,10 +8,14 @@ import ( "golang.org/x/crypto/argon2" ) +var ( + ErrPasswordSaltNotSet = errors.New("请配置PASSWORD_SALT") +) + func GeneratePasswordHash(password string) (string, error) { salt := utils.GetEnv("share.password_salt") if salt == "" { - return "", errors.New("请配置PASSWORD_SALT") + return "", ErrPasswordSaltNotSet } hash := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 32) return fmt.Sprintf("%x", hash), nil diff --git a/backend/internal/utils/http_result_test.go b/backend/test/utils/http_result_test.go similarity index 91% rename from backend/internal/utils/http_result_test.go rename to backend/test/utils/http_result_test.go index 8656686..a240261 100644 --- a/backend/internal/utils/http_result_test.go +++ b/backend/test/utils/http_result_test.go @@ -6,6 +6,8 @@ import ( "net/http/httptest" "testing" + "backend/internal/utils" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -17,7 +19,7 @@ func TestHTTPSuccessHandler(t *testing.T) { c := e.NewContext(req, rec) data := map[string]interface{}{"result": "success"} - err := HTTPSuccessHandler(c, data) + err := utils.HTTPSuccessHandler(c, data) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rec.Code) @@ -40,7 +42,7 @@ func TestHTTPErrorHandler(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - err := HTTPErrorHandler(c, assert.AnError) + err := utils.HTTPErrorHandler(c, assert.AnError) assert.NoError(t, err) assert.Equal(t, http.StatusBadRequest, rec.Code) diff --git a/backend/internal/utils/password_test.go b/backend/test/utils/password_test.go similarity index 72% rename from backend/internal/utils/password_test.go rename to backend/test/utils/password_test.go index 6584560..bcec6f3 100644 --- a/backend/internal/utils/password_test.go +++ b/backend/test/utils/password_test.go @@ -1,30 +1,31 @@ package utils import ( - "os" + "bytes" + "fmt" "testing" + "backend/internal/utils" + u "pkg/utils" + "github.com/stretchr/testify/assert" ) func TestGeneratePasswordHash(t *testing.T) { - // 保存原始环境变量 - originalSalt := os.Getenv("share.password_salt") - defer os.Setenv("share.password_salt", originalSalt) tests := []struct { name string password string salt string expectError bool - errorMsg string + err error }{ { name: "share.password_salt未配置", password: "testpassword", salt: "", expectError: true, - errorMsg: "请配置share.password_salt", + err: utils.ErrPasswordSaltNotSet, }, { name: "正常生成哈希", @@ -37,21 +38,23 @@ func TestGeneratePasswordHash(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 设置环境变量 - if tt.salt != "" { - os.Setenv("share.password_salt", tt.salt) - } else { - os.Unsetenv("share.password_salt") - } + u.InitEnv(u.EnvOption{ + ConfigData: bytes.NewBuffer([]byte(fmt.Sprintf(` + share: + password_salt: %s + `, tt.salt))), + }) + u.SetEnv("share.password_salt", tt.salt) - hash, err := GeneratePasswordHash(tt.password) + hash, err := utils.GeneratePasswordHash(tt.password) if tt.expectError { if err == nil { t.Errorf("期望错误,但得到了 nil") return } - if err.Error() != tt.errorMsg { - t.Errorf("期望错误信息 '%s',但得到了 '%s'", tt.errorMsg, err.Error()) + if err != tt.err { + t.Errorf("期望错误信息 '%s',但得到了 '%s'", tt.err.Error(), err.Error()) } return } diff --git a/pkg/utils/env.go b/pkg/utils/env.go index 0fba7c2..d4b8f57 100644 --- a/pkg/utils/env.go +++ b/pkg/utils/env.go @@ -1,53 +1,99 @@ package utils import ( + "io" "strings" + "sync" "github.com/spf13/viper" ) -var v *viper.Viper +var ( + v *viper.Viper + envOnce sync.Once +) -func init() { - InitEnv() -} - -func InitEnv() { +func InitEnv(props EnvOption) { if v != nil { return } - v = viper.New() - v.SetConfigName("config.yaml") - v.SetConfigType("yaml") - v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - v.AddConfigPath(".") - v.AddConfigPath("../") - v.AutomaticEnv() - v.WatchConfig() - err := v.ReadInConfig() - if err != nil { - panic(err) - // if _, ok := err.(viper.ConfigFileNotFoundError); !ok { - // // 只有当错误不是"配置文件未找到"时才 panic - // panic(err) - // } + envOnce.Do(func() { + v = viper.New() + if props.ConfigData != nil { + v.ReadConfig(props.ConfigData) + return + } + for _, name := range props.ConfigName { + v.SetConfigName(name) + } + for _, viperConfigType := range props.ConfigType { + v.SetConfigType(viperConfigType) + } + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + for _, path := range props.ConfigPath { + v.AddConfigPath(path) + } + v.AutomaticEnv() + v.WatchConfig() + err := v.ReadInConfig() + if err != nil { + panic(err) + } + }) +} + +type Option interface { + applyTo(*EnvOption) +} + +type EnvOption struct { + DefaultValue string + ConfigPath []string + ConfigName []string + ConfigType []string + ConfigData io.Reader // 测试环境使用 +} + +type WithDefaultValue string + +func (o WithDefaultValue) applyTo(props *EnvOption) { + props.DefaultValue = string(o) +} + +func getEnvOptions(options ...Option) EnvOption { + props := EnvOption{ + DefaultValue: "", + ConfigPath: []string{".", "../"}, + ConfigName: []string{"config"}, + ConfigType: []string{"yaml"}, } + for _, option := range options { + option.applyTo(&props) + } + return props } -func GetEnv(key string) string { - InitEnv() - return v.GetString(key) -} - -func GetEnvWithDefault(key string, defaultValue string) string { +func GetEnv(key string, options ...Option) string { + props := getEnvOptions(options...) + InitEnv(props) value := v.GetString(key) - if value == "" { - return defaultValue + + if value == "" && props.DefaultValue != "" { + return props.DefaultValue } return value } +func GetEnvWithDefault(key string, defaultValue string) string { + return GetEnv(key, WithDefaultValue(defaultValue)) +} + func GetEnvMapString(key string) map[string]string { - InitEnv() + props := getEnvOptions() + InitEnv(props) return v.GetStringMapString(key) } + +func SetEnv(key string, value string) { + v.Set(key, value) +} diff --git a/pkg/utils/redis.go b/pkg/utils/redis.go index 1f8f38c..093f9dc 100644 --- a/pkg/utils/redis.go +++ b/pkg/utils/redis.go @@ -2,12 +2,16 @@ package utils import ( "context" + "sync" "github.com/redis/go-redis/v9" ) -var rdb *redis.Client = InitRedis() -var ctx = context.Background() +var ( + rdb *redis.Client + ctx = context.Background() + onceRedis sync.Once +) func InitRedis() *redis.Client { opt, err := redis.ParseURL(GetEnv("redis.url")) @@ -18,5 +22,10 @@ func InitRedis() *redis.Client { } func GetRedisClient() (*redis.Client, context.Context) { + onceRedis.Do(func() { + if rdb == nil { + rdb = InitRedis() + } + }) return rdb, ctx }