diff --git a/backend/internal/controllers/about.go b/backend/internal/controllers/about.go index a0ccb0b..179f0cc 100644 --- a/backend/internal/controllers/about.go +++ b/backend/internal/controllers/about.go @@ -32,7 +32,7 @@ func GetAbout(c *echo.Context) error { return utils.HTTPSuccessHandler(c, map[string]any{ "bg_url": u.GetEnv("about.bg_url"), - "content": u.GetEnvMapString("about.content"), + "content": u.GetEnvMap("about.content"), "email": u.GetEnv("about.email"), "name": u.GetEnv("about.name"), "url": u.GetEnv("about.url"), diff --git a/backend/internal/controllers/config.go b/backend/internal/controllers/config.go index d4dda68..4e37500 100644 --- a/backend/internal/controllers/config.go +++ b/backend/internal/controllers/config.go @@ -11,8 +11,8 @@ import ( func GetConfig(c *echo.Context) error { return utils.HTTPSuccessHandler(c, map[string]any{ - "site_title": u.GetEnvMapString("site.title"), - "site_desc": u.GetEnvMapString("site.desc"), + "site_title": u.GetEnvMap("site.title"), + "site_desc": u.GetEnvMap("site.desc"), "site_url": u.GetEnv("site.url"), "site_icon": u.GetEnvWithDefault("site.icon", "/logo.png"), "site_bg_url": u.GetEnvWithDefault("site.bg_url", "https://img.fudaoyuan.icu/api/1/random/?scale_min=1.5&webp=true&md=false&format=302"), diff --git a/backend/test/utils/password_test.go b/backend/test/utils/password_test.go index bcec6f3..3ca0a66 100644 --- a/backend/test/utils/password_test.go +++ b/backend/test/utils/password_test.go @@ -38,7 +38,7 @@ func TestGeneratePasswordHash(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 设置环境变量 - u.InitEnv(u.EnvOption{ + u.InitTestViper(u.EnvOption{ ConfigData: bytes.NewBuffer([]byte(fmt.Sprintf(` share: password_salt: %s diff --git a/pkg/utils/env.go b/pkg/utils/env.go index 86ae50c..ba50d36 100644 --- a/pkg/utils/env.go +++ b/pkg/utils/env.go @@ -13,33 +13,42 @@ var ( envOnce sync.Once ) -func InitEnv(props EnvOption) { - if v != nil { - return +func createViperInstance(props EnvOption) *viper.Viper { + instance := viper.New() + for _, viperConfigType := range props.ConfigType { + instance.SetConfigType(viperConfigType) } + if props.ConfigData != nil { + instance.ReadConfig(props.ConfigData) + return instance + } + for _, name := range props.ConfigName { + instance.SetConfigName(name) + } + instance.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + for _, path := range props.ConfigPath { + instance.AddConfigPath(path) + } + instance.AutomaticEnv() + instance.WatchConfig() + if err := instance.ReadInConfig(); err != nil { + panic(err) + } + return instance +} + +func InitTestViper(props EnvOption) *viper.Viper { + instance := createViperInstance(props) + v = instance + envOnce.Do(func() {}) // 消费 once,防止 GetViperClient 覆盖已注入的实例 + return instance +} + +func GetViperClient() *viper.Viper { envOnce.Do(func() { - v = viper.New() - for _, viperConfigType := range props.ConfigType { - v.SetConfigType(viperConfigType) - } - if props.ConfigData != nil { - v.ReadConfig(props.ConfigData) - return - } - for _, name := range props.ConfigName { - v.SetConfigName(name) - } - v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - for _, path := range props.ConfigPath { - v.AddConfigPath(path) - } - v.AutomaticEnv() - v.WatchConfig() - err := v.ReadInConfig() - if err != nil { - panic(err) - } + v = createViperInstance(getEnvOptions()) }) + return v } type Option interface { @@ -75,8 +84,7 @@ func getEnvOptions(options ...Option) EnvOption { func GetEnv(key string, options ...Option) string { props := getEnvOptions(options...) - InitEnv(props) - value := v.GetString(key) + value := GetViperClient().GetString(key) if value == "" && props.DefaultValue != "" { return props.DefaultValue @@ -88,12 +96,10 @@ func GetEnvWithDefault(key string, defaultValue string) string { return GetEnv(key, WithDefaultValue(defaultValue)) } -func GetEnvMapString(key string) map[string]string { - props := getEnvOptions() - InitEnv(props) - return v.GetStringMapString(key) +func GetEnvMap(key string) map[string]any { + return GetViperClient().GetStringMap(key) } func SetEnv(key string, value string) { - v.Set(key, value) + GetViperClient().Set(key, value) } diff --git a/pkg/utils/test/env_test.go b/pkg/utils/test/env_test.go index 8efb705..f4038bb 100644 --- a/pkg/utils/test/env_test.go +++ b/pkg/utils/test/env_test.go @@ -22,7 +22,7 @@ func TestInitEnvAndGetEnv(t *testing.T) { ConfigData: bytes.NewBufferString(jsonData), ConfigType: []string{"json"}, } - utils.InitEnv(props) + utils.InitTestViper(props) // GetEnv应能拿到值 val := utils.GetEnv("test.value")