refactor: update environment variable handling to use GetEnvMap for improved type safety and add feature extraction functionality

This commit is contained in:
keven1024
2026-04-04 21:42:25 +08:00
parent 95ab8b97da
commit c50bb5d0bf
5 changed files with 42 additions and 36 deletions

View File

@@ -32,7 +32,7 @@ func GetAbout(c *echo.Context) error {
return utils.HTTPSuccessHandler(c, map[string]any{
"bg_url": u.GetEnv("about.bg_url"),
"content": u.GetEnvMapString("about.content"),
"content": u.GetEnvMap("about.content"),
"email": u.GetEnv("about.email"),
"name": u.GetEnv("about.name"),
"url": u.GetEnv("about.url"),

View File

@@ -11,8 +11,8 @@ import (
func GetConfig(c *echo.Context) error {
return utils.HTTPSuccessHandler(c, map[string]any{
"site_title": u.GetEnvMapString("site.title"),
"site_desc": u.GetEnvMapString("site.desc"),
"site_title": u.GetEnvMap("site.title"),
"site_desc": u.GetEnvMap("site.desc"),
"site_url": u.GetEnv("site.url"),
"site_icon": u.GetEnvWithDefault("site.icon", "/logo.png"),
"site_bg_url": u.GetEnvWithDefault("site.bg_url", "https://img.fudaoyuan.icu/api/1/random/?scale_min=1.5&webp=true&md=false&format=302"),

View File

@@ -38,7 +38,7 @@ func TestGeneratePasswordHash(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置环境变量
u.InitEnv(u.EnvOption{
u.InitTestViper(u.EnvOption{
ConfigData: bytes.NewBuffer([]byte(fmt.Sprintf(`
share:
password_salt: %s

View File

@@ -13,33 +13,42 @@ var (
envOnce sync.Once
)
func InitEnv(props EnvOption) {
if v != nil {
return
func createViperInstance(props EnvOption) *viper.Viper {
instance := viper.New()
for _, viperConfigType := range props.ConfigType {
instance.SetConfigType(viperConfigType)
}
if props.ConfigData != nil {
instance.ReadConfig(props.ConfigData)
return instance
}
for _, name := range props.ConfigName {
instance.SetConfigName(name)
}
instance.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
for _, path := range props.ConfigPath {
instance.AddConfigPath(path)
}
instance.AutomaticEnv()
instance.WatchConfig()
if err := instance.ReadInConfig(); err != nil {
panic(err)
}
return instance
}
func InitTestViper(props EnvOption) *viper.Viper {
instance := createViperInstance(props)
v = instance
envOnce.Do(func() {}) // 消费 once防止 GetViperClient 覆盖已注入的实例
return instance
}
func GetViperClient() *viper.Viper {
envOnce.Do(func() {
v = viper.New()
for _, viperConfigType := range props.ConfigType {
v.SetConfigType(viperConfigType)
}
if props.ConfigData != nil {
v.ReadConfig(props.ConfigData)
return
}
for _, name := range props.ConfigName {
v.SetConfigName(name)
}
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
for _, path := range props.ConfigPath {
v.AddConfigPath(path)
}
v.AutomaticEnv()
v.WatchConfig()
err := v.ReadInConfig()
if err != nil {
panic(err)
}
v = createViperInstance(getEnvOptions())
})
return v
}
type Option interface {
@@ -75,8 +84,7 @@ func getEnvOptions(options ...Option) EnvOption {
func GetEnv(key string, options ...Option) string {
props := getEnvOptions(options...)
InitEnv(props)
value := v.GetString(key)
value := GetViperClient().GetString(key)
if value == "" && props.DefaultValue != "" {
return props.DefaultValue
@@ -88,12 +96,10 @@ func GetEnvWithDefault(key string, defaultValue string) string {
return GetEnv(key, WithDefaultValue(defaultValue))
}
func GetEnvMapString(key string) map[string]string {
props := getEnvOptions()
InitEnv(props)
return v.GetStringMapString(key)
func GetEnvMap(key string) map[string]any {
return GetViperClient().GetStringMap(key)
}
func SetEnv(key string, value string) {
v.Set(key, value)
GetViperClient().Set(key, value)
}

View File

@@ -22,7 +22,7 @@ func TestInitEnvAndGetEnv(t *testing.T) {
ConfigData: bytes.NewBufferString(jsonData),
ConfigType: []string{"json"},
}
utils.InitEnv(props)
utils.InitTestViper(props)
// GetEnv应能拿到值
val := utils.GetEnv("test.value")