From 1298bf79d3d2473db24bad81cc36fc35b26912d8 Mon Sep 17 00:00:00 2001 From: keven1024 Date: Mon, 6 Apr 2026 11:49:32 +0800 Subject: [PATCH] refactor(backend): enhance DownloadShare functionality with context management and improve Redis share info handling --- backend/internal/controllers/download.go | 99 ++++++++++++------------ pkg/models/file.go | 20 +++-- pkg/models/share.go | 19 +++-- pkg/models/stat.go | 22 +++--- worker/internal/services/file.go | 4 +- 5 files changed, 94 insertions(+), 70 deletions(-) diff --git a/backend/internal/controllers/download.go b/backend/internal/controllers/download.go index 8f3b1f5..b95ab1b 100644 --- a/backend/internal/controllers/download.go +++ b/backend/internal/controllers/download.go @@ -2,6 +2,7 @@ package controllers import ( "backend/internal/utils" + "context" "fmt" "pkg/models" u "pkg/utils" @@ -9,6 +10,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v5" + "github.com/samber/lo" "github.com/spf13/cast" ) @@ -84,64 +86,65 @@ func VaildateShare(c *echo.Context) error { return utils.HTTPErrorHandler(c, ErrInvalidSharePassword) } } - // 如果下载次数为0,则设置为-1 防止空值问题 - if shareInfo.ViewNum < 1 { - return utils.HTTPErrorHandler(c, ErrInsufficientDownloadQuota) - } - downloadWindow := u.GetEnvWithDefault("share.download_window", "12") - token := jwt.NewWithClaims(jwt.SigningMethodHS256, DownloadShareClaims{ - ShareId: r.ShareId, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(cast.ToDuration(downloadWindow + "h"))), - }, - }) + return u.WithLocker(context.Background(), "015:shareInfoMap:"+r.ShareId, 0, func(ctx context.Context) error { + shareInfo, err := models.GetRedisShareInfo(r.ShareId) + if err != nil || shareInfo == nil { + return utils.HTTPErrorHandler(c, lo.Ternary(err != nil, err, ErrShareNotFound)) + } + if shareInfo.ViewNum < 1 { + return utils.HTTPErrorHandler(c, ErrInsufficientDownloadQuota) + } + downloadWindow := u.GetEnvWithDefault("share.download_window", "12") + token := jwt.NewWithClaims(jwt.SigningMethodHS256, DownloadShareClaims{ + ShareId: r.ShareId, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(cast.ToDuration(downloadWindow + "h"))), + }, + }) - // Sign and get the complete encoded token as a string using the secret - downloadToken, err := token.SignedString([]byte(u.GetEnv("share.download_secret"))) - if err != nil { - return utils.HTTPErrorHandler(c, err) - } - if shareInfo.Type == models.ShareTypeFile { - fileInfo, err := models.GetRedisFileInfo(shareInfo.Data) + // Sign and get the complete encoded token as a string using the secret + downloadToken, err := token.SignedString([]byte(u.GetEnv("share.download_secret"))) if err != nil { return utils.HTTPErrorHandler(c, err) } - if fileInfo == nil { - return utils.HTTPErrorHandler(c, ErrShareFileNotFound) + if shareInfo.Type == models.ShareTypeFile { + fileInfo, err := models.GetRedisFileInfo(shareInfo.Data) + if err != nil { + return utils.HTTPErrorHandler(c, err) + } + if fileInfo == nil { + return utils.HTTPErrorHandler(c, ErrShareFileNotFound) + } + if fileInfo.FileType != models.FileTypeUpload { + return utils.HTTPErrorHandler(c, ErrInvalidShareFileState) + } } - if fileInfo.FileType != models.FileTypeUpload { - return utils.HTTPErrorHandler(c, ErrInvalidShareFileState) + // download_nums 必须放在创建token的时候减掉,不然多线程下载会导致多次减掉 + err = models.SetRedisShareInfo(r.ShareId, func(shareInfo *models.RedisShareInfo) *models.RedisShareInfo { + shareInfo.ViewNum -= 1 + return shareInfo + }) + if err != nil { + return utils.HTTPErrorHandler(c, err) } - } - // download_nums 必须放在创建token的时候减掉,不然多线程下载会导致多次减掉 - latestViewNum := shareInfo.ViewNum - 1 - // 如果下载次数为0,则设置为-1 防止空值问题 - if latestViewNum < 1 { - latestViewNum = -1 - } - err = models.SetRedisShareInfo(r.ShareId, models.RedisShareInfo{ - ViewNum: latestViewNum, - }) - if err != nil { - return utils.HTTPErrorHandler(c, err) - } - // 统计分享数 - currentDate := time.Now().Format("2006-01-02") - err = models.SetRedisStat(currentDate, func(stat *models.StatData) *models.StatData { - stat.DownloadNum += 1 - return stat - }) - if err != nil { - return utils.HTTPErrorHandler(c, err) - } + // 统计分享数 + currentDate := time.Now().Format("2006-01-02") + err = models.SetRedisStat(currentDate, func(stat *models.StatData) *models.StatData { + stat.DownloadNum += 1 + return stat + }) + if err != nil { + return utils.HTTPErrorHandler(c, err) + } - if shareInfo.Type == models.ShareTypeFile { + if shareInfo.Type == models.ShareTypeFile { + return utils.HTTPSuccessHandler(c, map[string]any{ + "token": downloadToken, + }) + } return utils.HTTPSuccessHandler(c, map[string]any{ "token": downloadToken, }) - } - return utils.HTTPSuccessHandler(c, map[string]any{ - "token": downloadToken, }) } diff --git a/pkg/models/file.go b/pkg/models/file.go index ff36f7a..10d5adc 100644 --- a/pkg/models/file.go +++ b/pkg/models/file.go @@ -3,8 +3,10 @@ package models import ( "encoding/json" "pkg/utils" + "time" "github.com/redis/rueidis" + "github.com/spf13/cast" ) type FileInfo struct { @@ -25,6 +27,7 @@ type RedisFileInfo struct { FileInfo FileType FileType `json:"type"` CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` Expire int64 `json:"expire"` // 只有上传文件(init)的时候有这个字段 } @@ -44,21 +47,28 @@ func GetRedisFileInfo(fileId string) (*RedisFileInfo, error) { return &fileInfoData, nil } -func SetRedisFileInfo(fileId string, handler func(fileInfo *RedisFileInfo) *RedisFileInfo) error { +func SetRedisFileInfo(fileId string, handler func(fileInfo *RedisFileInfo) *RedisFileInfo) (*RedisFileInfo, error) { rdb, ctx := utils.GetRedisClient() old_fileInfo, err := GetRedisFileInfo(fileId) if err != nil { - return err + return nil, err } if old_fileInfo == nil { - old_fileInfo = &RedisFileInfo{} + old_fileInfo = &RedisFileInfo{ + CreatedAt: time.Now().Unix(), + Expire: cast.ToInt64(utils.GetEnvWithDefault("upload.remove_expire", "2")) * 3600, + } } fileInfo := handler(old_fileInfo) + fileInfo.UpdatedAt = time.Now().Unix() jsonData, err := json.Marshal(fileInfo) if err != nil { - return err + return nil, err } - return rdb.Do(ctx, rdb.B().Hset().Key("015:fileInfoMap").FieldValue().FieldValue(fileId, string(jsonData)).Build()).Error() + if err := rdb.Do(ctx, rdb.B().Hset().Key("015:fileInfoMap").FieldValue().FieldValue(fileId, string(jsonData)).Build()).Error(); err != nil { + return nil, err + } + return fileInfo, nil } func GetRedisFileInfoAll() (map[string]string, error) { diff --git a/pkg/models/share.go b/pkg/models/share.go index 5e8f055..999d950 100644 --- a/pkg/models/share.go +++ b/pkg/models/share.go @@ -13,6 +13,7 @@ import ( type RedisShareInfo struct { // Id string `json:"id"` CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` Owner string `json:"owner"` Type ShareType `json:"type"` Data string `json:"data"` // 分享数据 文件分享为文件id 文本分享为文本内容 @@ -51,26 +52,32 @@ func GetRedisShareInfo(shareId string) (*RedisShareInfo, error) { return &shareInfoData, nil } -func SetRedisShareInfo(shareId string, handler func(shareInfo *RedisShareInfo) *RedisShareInfo) error { +func SetRedisShareInfo(shareId string, handler func(shareInfo *RedisShareInfo) *RedisShareInfo) (*RedisShareInfo, error) { rdb, ctx := utils.GetRedisClient() old_shareInfo, err := GetRedisShareInfo(shareId) if err != nil { - return err + return nil, err } if old_shareInfo == nil { - old_shareInfo = &RedisShareInfo{} + old_shareInfo = &RedisShareInfo{ + CreatedAt: time.Now().Unix(), + } } shareInfo := handler(old_shareInfo) + shareInfo.UpdatedAt = time.Now().Unix() jsonData, err := json.Marshal(shareInfo) if err != nil { - return err + return nil, err } - return rdb.Do( + if err := rdb.Do( ctx, rdb.B().Set(). Key(fmt.Sprintf("015:shareInfoMap:%s", shareId)). Value(string(jsonData)). Ex(time.Until(time.Unix(shareInfo.ExpireAt, 0))). Build(), - ).Error() + ).Error(); err != nil { + return nil, err + } + return shareInfo, nil } diff --git a/pkg/models/stat.go b/pkg/models/stat.go index a1d0560..3bbebfd 100644 --- a/pkg/models/stat.go +++ b/pkg/models/stat.go @@ -33,28 +33,32 @@ func GetRedisStat(key string) (*StatData, error) { return &stat, nil } -func SetRedisStat(key string, handler func(stat *StatData) *StatData) error { - return utils.WithLocker(context.Background(), "015:stat:"+key, 0, func(ctx context.Context) error { +func SetRedisStat(key string, handler func(stat *StatData) *StatData) (*StatData, error) { + var updatedStat *StatData + err := utils.WithLocker(context.Background(), "015:stat:"+key, 0, func(ctx context.Context) error { rdb, _ := utils.GetRedisClient() old_stat, err := GetRedisStat(key) if err != nil { return err } if old_stat == nil { - old_stat = &StatData{ - FileSize: 0, - FileNum: 0, - ShareNum: 0, - DownloadNum: 0, - } + old_stat = &StatData{} } stat := handler(old_stat) jsonData, err := json.Marshal(stat) if err != nil { return err } - return rdb.Do(ctx, rdb.B().Hset().Key("015:stat").FieldValue().FieldValue(key, string(jsonData)).Build()).Error() + if err := rdb.Do(ctx, rdb.B().Hset().Key("015:stat").FieldValue().FieldValue(key, string(jsonData)).Build()).Error(); err != nil { + return err + } + updatedStat = stat + return nil }) + if err != nil { + return nil, err + } + return updatedStat, nil } func GetRedisStatAll() (map[string]string, error) { diff --git a/worker/internal/services/file.go b/worker/internal/services/file.go index 189f3c9..8df06a9 100644 --- a/worker/internal/services/file.go +++ b/worker/internal/services/file.go @@ -53,8 +53,8 @@ func GenStandardFile(filePath string, mimeType string) (GenStandardFileReturn, e if err != nil { return GenStandardFileReturn{}, err } - if err := models.SetRedisFileInfo(fileId, models.RedisFileInfo{ - FileInfo: models.FileInfo{ + if err := models.SetRedisFileInfo(fileId, func(fileInfo *models.RedisFileInfo) *models.RedisFileInfo { + fileInfo.FileInfo = models.FileInfo{ FileSize: fileSize, FileHash: fileHash, MimeType: mimeType,