refactor(backend): enhance DownloadShare functionality with context management and improve Redis share info handling

This commit is contained in:
keven1024
2026-04-06 11:49:32 +08:00
parent 83f6be0486
commit 1298bf79d3
5 changed files with 94 additions and 70 deletions

View File

@@ -2,6 +2,7 @@ package controllers
import (
"backend/internal/utils"
"context"
"fmt"
"pkg/models"
u "pkg/utils"
@@ -9,6 +10,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v5"
"github.com/samber/lo"
"github.com/spf13/cast"
)
@@ -84,64 +86,65 @@ func VaildateShare(c *echo.Context) error {
return utils.HTTPErrorHandler(c, ErrInvalidSharePassword)
}
}
// 如果下载次数为0则设置为-1 防止空值问题
if shareInfo.ViewNum < 1 {
return utils.HTTPErrorHandler(c, ErrInsufficientDownloadQuota)
}
downloadWindow := u.GetEnvWithDefault("share.download_window", "12")
token := jwt.NewWithClaims(jwt.SigningMethodHS256, DownloadShareClaims{
ShareId: r.ShareId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(cast.ToDuration(downloadWindow + "h"))),
},
})
return u.WithLocker(context.Background(), "015:shareInfoMap:"+r.ShareId, 0, func(ctx context.Context) error {
shareInfo, err := models.GetRedisShareInfo(r.ShareId)
if err != nil || shareInfo == nil {
return utils.HTTPErrorHandler(c, lo.Ternary(err != nil, err, ErrShareNotFound))
}
if shareInfo.ViewNum < 1 {
return utils.HTTPErrorHandler(c, ErrInsufficientDownloadQuota)
}
downloadWindow := u.GetEnvWithDefault("share.download_window", "12")
token := jwt.NewWithClaims(jwt.SigningMethodHS256, DownloadShareClaims{
ShareId: r.ShareId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(cast.ToDuration(downloadWindow + "h"))),
},
})
// Sign and get the complete encoded token as a string using the secret
downloadToken, err := token.SignedString([]byte(u.GetEnv("share.download_secret")))
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
if shareInfo.Type == models.ShareTypeFile {
fileInfo, err := models.GetRedisFileInfo(shareInfo.Data)
// Sign and get the complete encoded token as a string using the secret
downloadToken, err := token.SignedString([]byte(u.GetEnv("share.download_secret")))
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
if fileInfo == nil {
return utils.HTTPErrorHandler(c, ErrShareFileNotFound)
if shareInfo.Type == models.ShareTypeFile {
fileInfo, err := models.GetRedisFileInfo(shareInfo.Data)
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
if fileInfo == nil {
return utils.HTTPErrorHandler(c, ErrShareFileNotFound)
}
if fileInfo.FileType != models.FileTypeUpload {
return utils.HTTPErrorHandler(c, ErrInvalidShareFileState)
}
}
if fileInfo.FileType != models.FileTypeUpload {
return utils.HTTPErrorHandler(c, ErrInvalidShareFileState)
// download_nums 必须放在创建token的时候减掉不然多线程下载会导致多次减掉
err = models.SetRedisShareInfo(r.ShareId, func(shareInfo *models.RedisShareInfo) *models.RedisShareInfo {
shareInfo.ViewNum -= 1
return shareInfo
})
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
}
// download_nums 必须放在创建token的时候减掉不然多线程下载会导致多次减掉
latestViewNum := shareInfo.ViewNum - 1
// 如果下载次数为0则设置为-1 防止空值问题
if latestViewNum < 1 {
latestViewNum = -1
}
err = models.SetRedisShareInfo(r.ShareId, models.RedisShareInfo{
ViewNum: latestViewNum,
})
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
// 统计分享数
currentDate := time.Now().Format("2006-01-02")
err = models.SetRedisStat(currentDate, func(stat *models.StatData) *models.StatData {
stat.DownloadNum += 1
return stat
})
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
// 统计分享数
currentDate := time.Now().Format("2006-01-02")
err = models.SetRedisStat(currentDate, func(stat *models.StatData) *models.StatData {
stat.DownloadNum += 1
return stat
})
if err != nil {
return utils.HTTPErrorHandler(c, err)
}
if shareInfo.Type == models.ShareTypeFile {
if shareInfo.Type == models.ShareTypeFile {
return utils.HTTPSuccessHandler(c, map[string]any{
"token": downloadToken,
})
}
return utils.HTTPSuccessHandler(c, map[string]any{
"token": downloadToken,
})
}
return utils.HTTPSuccessHandler(c, map[string]any{
"token": downloadToken,
})
}

View File

@@ -3,8 +3,10 @@ package models
import (
"encoding/json"
"pkg/utils"
"time"
"github.com/redis/rueidis"
"github.com/spf13/cast"
)
type FileInfo struct {
@@ -25,6 +27,7 @@ type RedisFileInfo struct {
FileInfo
FileType FileType `json:"type"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
Expire int64 `json:"expire"` // 只有上传文件(init)的时候有这个字段
}
@@ -44,21 +47,28 @@ func GetRedisFileInfo(fileId string) (*RedisFileInfo, error) {
return &fileInfoData, nil
}
func SetRedisFileInfo(fileId string, handler func(fileInfo *RedisFileInfo) *RedisFileInfo) error {
func SetRedisFileInfo(fileId string, handler func(fileInfo *RedisFileInfo) *RedisFileInfo) (*RedisFileInfo, error) {
rdb, ctx := utils.GetRedisClient()
old_fileInfo, err := GetRedisFileInfo(fileId)
if err != nil {
return err
return nil, err
}
if old_fileInfo == nil {
old_fileInfo = &RedisFileInfo{}
old_fileInfo = &RedisFileInfo{
CreatedAt: time.Now().Unix(),
Expire: cast.ToInt64(utils.GetEnvWithDefault("upload.remove_expire", "2")) * 3600,
}
}
fileInfo := handler(old_fileInfo)
fileInfo.UpdatedAt = time.Now().Unix()
jsonData, err := json.Marshal(fileInfo)
if err != nil {
return err
return nil, err
}
return rdb.Do(ctx, rdb.B().Hset().Key("015:fileInfoMap").FieldValue().FieldValue(fileId, string(jsonData)).Build()).Error()
if err := rdb.Do(ctx, rdb.B().Hset().Key("015:fileInfoMap").FieldValue().FieldValue(fileId, string(jsonData)).Build()).Error(); err != nil {
return nil, err
}
return fileInfo, nil
}
func GetRedisFileInfoAll() (map[string]string, error) {

View File

@@ -13,6 +13,7 @@ import (
type RedisShareInfo struct {
// Id string `json:"id"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
Owner string `json:"owner"`
Type ShareType `json:"type"`
Data string `json:"data"` // 分享数据 文件分享为文件id 文本分享为文本内容
@@ -51,26 +52,32 @@ func GetRedisShareInfo(shareId string) (*RedisShareInfo, error) {
return &shareInfoData, nil
}
func SetRedisShareInfo(shareId string, handler func(shareInfo *RedisShareInfo) *RedisShareInfo) error {
func SetRedisShareInfo(shareId string, handler func(shareInfo *RedisShareInfo) *RedisShareInfo) (*RedisShareInfo, error) {
rdb, ctx := utils.GetRedisClient()
old_shareInfo, err := GetRedisShareInfo(shareId)
if err != nil {
return err
return nil, err
}
if old_shareInfo == nil {
old_shareInfo = &RedisShareInfo{}
old_shareInfo = &RedisShareInfo{
CreatedAt: time.Now().Unix(),
}
}
shareInfo := handler(old_shareInfo)
shareInfo.UpdatedAt = time.Now().Unix()
jsonData, err := json.Marshal(shareInfo)
if err != nil {
return err
return nil, err
}
return rdb.Do(
if err := rdb.Do(
ctx,
rdb.B().Set().
Key(fmt.Sprintf("015:shareInfoMap:%s", shareId)).
Value(string(jsonData)).
Ex(time.Until(time.Unix(shareInfo.ExpireAt, 0))).
Build(),
).Error()
).Error(); err != nil {
return nil, err
}
return shareInfo, nil
}

View File

@@ -33,28 +33,32 @@ func GetRedisStat(key string) (*StatData, error) {
return &stat, nil
}
func SetRedisStat(key string, handler func(stat *StatData) *StatData) error {
return utils.WithLocker(context.Background(), "015:stat:"+key, 0, func(ctx context.Context) error {
func SetRedisStat(key string, handler func(stat *StatData) *StatData) (*StatData, error) {
var updatedStat *StatData
err := utils.WithLocker(context.Background(), "015:stat:"+key, 0, func(ctx context.Context) error {
rdb, _ := utils.GetRedisClient()
old_stat, err := GetRedisStat(key)
if err != nil {
return err
}
if old_stat == nil {
old_stat = &StatData{
FileSize: 0,
FileNum: 0,
ShareNum: 0,
DownloadNum: 0,
}
old_stat = &StatData{}
}
stat := handler(old_stat)
jsonData, err := json.Marshal(stat)
if err != nil {
return err
}
return rdb.Do(ctx, rdb.B().Hset().Key("015:stat").FieldValue().FieldValue(key, string(jsonData)).Build()).Error()
if err := rdb.Do(ctx, rdb.B().Hset().Key("015:stat").FieldValue().FieldValue(key, string(jsonData)).Build()).Error(); err != nil {
return err
}
updatedStat = stat
return nil
})
if err != nil {
return nil, err
}
return updatedStat, nil
}
func GetRedisStatAll() (map[string]string, error) {

View File

@@ -53,8 +53,8 @@ func GenStandardFile(filePath string, mimeType string) (GenStandardFileReturn, e
if err != nil {
return GenStandardFileReturn{}, err
}
if err := models.SetRedisFileInfo(fileId, models.RedisFileInfo{
FileInfo: models.FileInfo{
if err := models.SetRedisFileInfo(fileId, func(fileInfo *models.RedisFileInfo) *models.RedisFileInfo {
fileInfo.FileInfo = models.FileInfo{
FileSize: fileSize,
FileHash: fileHash,
MimeType: mimeType,