feat(backend): add password hashing error handling and implement HTTP utility tests

This commit is contained in:
keven1024
2025-12-27 11:04:21 +08:00
parent 185f7a3503
commit 31c0736562
5 changed files with 113 additions and 49 deletions

View File

@@ -8,10 +8,14 @@ import (
"golang.org/x/crypto/argon2"
)
var (
ErrPasswordSaltNotSet = errors.New("请配置PASSWORD_SALT")
)
func GeneratePasswordHash(password string) (string, error) {
salt := utils.GetEnv("share.password_salt")
if salt == "" {
return "", errors.New("请配置PASSWORD_SALT")
return "", ErrPasswordSaltNotSet
}
hash := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 32)
return fmt.Sprintf("%x", hash), nil

View File

@@ -6,6 +6,8 @@ import (
"net/http/httptest"
"testing"
"backend/internal/utils"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -17,7 +19,7 @@ func TestHTTPSuccessHandler(t *testing.T) {
c := e.NewContext(req, rec)
data := map[string]interface{}{"result": "success"}
err := HTTPSuccessHandler(c, data)
err := utils.HTTPSuccessHandler(c, data)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
@@ -40,7 +42,7 @@ func TestHTTPErrorHandler(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := HTTPErrorHandler(c, assert.AnError)
err := utils.HTTPErrorHandler(c, assert.AnError)
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, rec.Code)

View File

@@ -1,30 +1,31 @@
package utils
import (
"os"
"bytes"
"fmt"
"testing"
"backend/internal/utils"
u "pkg/utils"
"github.com/stretchr/testify/assert"
)
func TestGeneratePasswordHash(t *testing.T) {
// 保存原始环境变量
originalSalt := os.Getenv("share.password_salt")
defer os.Setenv("share.password_salt", originalSalt)
tests := []struct {
name string
password string
salt string
expectError bool
errorMsg string
err error
}{
{
name: "share.password_salt未配置",
password: "testpassword",
salt: "",
expectError: true,
errorMsg: "请配置share.password_salt",
err: utils.ErrPasswordSaltNotSet,
},
{
name: "正常生成哈希",
@@ -37,21 +38,23 @@ func TestGeneratePasswordHash(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置环境变量
if tt.salt != "" {
os.Setenv("share.password_salt", tt.salt)
} else {
os.Unsetenv("share.password_salt")
}
u.InitEnv(u.EnvOption{
ConfigData: bytes.NewBuffer([]byte(fmt.Sprintf(`
share:
password_salt: %s
`, tt.salt))),
})
u.SetEnv("share.password_salt", tt.salt)
hash, err := GeneratePasswordHash(tt.password)
hash, err := utils.GeneratePasswordHash(tt.password)
if tt.expectError {
if err == nil {
t.Errorf("期望错误,但得到了 nil")
return
}
if err.Error() != tt.errorMsg {
t.Errorf("期望错误信息 '%s',但得到了 '%s'", tt.errorMsg, err.Error())
if err != tt.err {
t.Errorf("期望错误信息 '%s',但得到了 '%s'", tt.err.Error(), err.Error())
}
return
}

View File

@@ -1,53 +1,99 @@
package utils
import (
"io"
"strings"
"sync"
"github.com/spf13/viper"
)
var v *viper.Viper
var (
v *viper.Viper
envOnce sync.Once
)
func init() {
InitEnv()
}
func InitEnv() {
func InitEnv(props EnvOption) {
if v != nil {
return
}
v = viper.New()
v.SetConfigName("config.yaml")
v.SetConfigType("yaml")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AddConfigPath(".")
v.AddConfigPath("../")
v.AutomaticEnv()
v.WatchConfig()
err := v.ReadInConfig()
if err != nil {
panic(err)
// if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
// // 只有当错误不是"配置文件未找到"时才 panic
// panic(err)
// }
envOnce.Do(func() {
v = viper.New()
if props.ConfigData != nil {
v.ReadConfig(props.ConfigData)
return
}
for _, name := range props.ConfigName {
v.SetConfigName(name)
}
for _, viperConfigType := range props.ConfigType {
v.SetConfigType(viperConfigType)
}
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
for _, path := range props.ConfigPath {
v.AddConfigPath(path)
}
v.AutomaticEnv()
v.WatchConfig()
err := v.ReadInConfig()
if err != nil {
panic(err)
}
})
}
type Option interface {
applyTo(*EnvOption)
}
type EnvOption struct {
DefaultValue string
ConfigPath []string
ConfigName []string
ConfigType []string
ConfigData io.Reader // 测试环境使用
}
type WithDefaultValue string
func (o WithDefaultValue) applyTo(props *EnvOption) {
props.DefaultValue = string(o)
}
func getEnvOptions(options ...Option) EnvOption {
props := EnvOption{
DefaultValue: "",
ConfigPath: []string{".", "../"},
ConfigName: []string{"config"},
ConfigType: []string{"yaml"},
}
for _, option := range options {
option.applyTo(&props)
}
return props
}
func GetEnv(key string) string {
InitEnv()
return v.GetString(key)
}
func GetEnvWithDefault(key string, defaultValue string) string {
func GetEnv(key string, options ...Option) string {
props := getEnvOptions(options...)
InitEnv(props)
value := v.GetString(key)
if value == "" {
return defaultValue
if value == "" && props.DefaultValue != "" {
return props.DefaultValue
}
return value
}
func GetEnvWithDefault(key string, defaultValue string) string {
return GetEnv(key, WithDefaultValue(defaultValue))
}
func GetEnvMapString(key string) map[string]string {
InitEnv()
props := getEnvOptions()
InitEnv(props)
return v.GetStringMapString(key)
}
func SetEnv(key string, value string) {
v.Set(key, value)
}

View File

@@ -2,12 +2,16 @@ package utils
import (
"context"
"sync"
"github.com/redis/go-redis/v9"
)
var rdb *redis.Client = InitRedis()
var ctx = context.Background()
var (
rdb *redis.Client
ctx = context.Background()
onceRedis sync.Once
)
func InitRedis() *redis.Client {
opt, err := redis.ParseURL(GetEnv("redis.url"))
@@ -18,5 +22,10 @@ func InitRedis() *redis.Client {
}
func GetRedisClient() (*redis.Client, context.Context) {
onceRedis.Do(func() {
if rdb == nil {
rdb = InitRedis()
}
})
return rdb, ctx
}