diff --git a/backend/internal/controllers/file.go b/backend/internal/controllers/file.go index 94c255c..be600c6 100644 --- a/backend/internal/controllers/file.go +++ b/backend/internal/controllers/file.go @@ -5,10 +5,9 @@ import ( "backend/internal/utils" "encoding/json" "errors" - "fmt" + "math" "mime/multipart" "os" - "path/filepath" "pkg/models" u "pkg/utils" "time" @@ -28,9 +27,20 @@ func CreateUploadTask(c echo.Context) error { return utils.HTTPErrorHandler(c, errors.New("调用接口参数错误")) } fileId := utils.GetFileId(r.FileHash, r.FileSize) - fileInfo, _ := models.GetRedisFileInfo(fileId) + fileInfo, err := models.GetRedisFileInfo(fileId) + if err != nil { + return utils.HTTPErrorHandler(c, err) + } if fileInfo != nil { + uploadPath, err := utils.GetUploadDirPath() + if err != nil { + return utils.HTTPErrorHandler(c, err) + } + sliceList, err := services.GetFileSliceList(fileId, uploadPath) + if err != nil { + return utils.HTTPErrorHandler(c, err) + } return utils.HTTPSuccessHandler(c, map[string]any{ "size": fileInfo.FileSize, "mime_type": fileInfo.MimeType, @@ -39,6 +49,7 @@ func CreateUploadTask(c echo.Context) error { "expire": fileInfo.Expire, "id": fileId, "chunk_size": fileInfo.ChunkSize, + "chunks": sliceList, }) } maxStorageSize, err := utils.GetFileSize(u.GetEnv("upload.maximum")) @@ -155,7 +166,7 @@ func UploadFileSlice(c echo.Context) error { return utils.HTTPErrorHandler(c, err) } - if err := services.CreateFileSlice(file, uploadPath, r.FileId, r.FileIndex); err != nil { + if _, err := services.CreateFileSlice(r.FileId, uploadPath, file, r.FileIndex); err != nil { return utils.HTTPErrorHandler(c, err) } @@ -192,13 +203,23 @@ func FinishUploadTask(c echo.Context) error { return utils.HTTPErrorHandler(c, errors.New("上传任务已过期")) } - // 合并文件切片 - uploadPath, _ := utils.GetUploadDirPath() - slicesPath := filepath.Join(uploadPath, fmt.Sprintf("%s_%s", r.FileId, "tmp")) + uploadPath, err := utils.GetUploadDirPath() + if err != nil { + return utils.HTTPErrorHandler(c, err) + } + + fileSliceList, err := services.GetFileSliceList(r.FileId, uploadPath) + if err != nil { + return utils.HTTPErrorHandler(c, err) + } + + if len(fileSliceList) != int(math.Ceil(float64(fileInfo.FileSize)/float64(fileInfo.ChunkSize))) { + return utils.HTTPErrorHandler(c, errors.New("文件切片不完整")) + } // 最终合并后的文件路径 - mergeFilePath := filepath.Join(uploadPath, r.FileId) - if err := services.MergeFileSlices(slicesPath, mergeFilePath); err != nil { + mergeFilePath, err := services.MergeFileSlices(r.FileId, uploadPath) + if err != nil { return utils.HTTPErrorHandler(c, err) } diff --git a/backend/internal/services/file.go b/backend/internal/services/file.go index 4be5a9f..b4e5fba 100644 --- a/backend/internal/services/file.go +++ b/backend/internal/services/file.go @@ -5,59 +5,71 @@ import ( "io" "os" "path/filepath" + "sort" "strconv" ) -func CreateFileSlice(fileSlice io.Reader, uploadPath string, fileId string, fileIndex int64) error { +func CreateFileSlice(fileId string, uploadPath string, fileSlice io.Reader, fileIndex int64) (string, error) { filePath := filepath.Join(uploadPath, fmt.Sprintf("%s_%s", fileId, "tmp")) if err := os.MkdirAll(filePath, 0755); err != nil { - return err + return "", err } dst, err := os.Create(filepath.Join(filePath, fmt.Sprintf("%d", fileIndex))) if err != nil { - return err + return "", err } defer dst.Close() if _, err = io.Copy(dst, fileSlice); err != nil { - return err + return "", err } - return nil + return filePath, nil } -// MergeFileSlices 合并文件切片 -func MergeFileSlices(slicesPath string, mergeFilePath string) error { - // 创建最终文件 - destFile, err := os.Create(mergeFilePath) - if err != nil { - return fmt.Errorf("创建合并文件失败: %v", err) - } - defer destFile.Close() - - // 读取目录下的所有文件 +func GetFileSliceList(fileId string, uploadPath string) ([]int, error) { + slicesPath := filepath.Join(uploadPath, fmt.Sprintf("%s_%s", fileId, "tmp")) files, err := os.ReadDir(slicesPath) if err != nil { - return fmt.Errorf("读取切片目录失败: %v", err) + return nil, fmt.Errorf("读取切片目录失败: %v", err) } - - // 按照索引排序文件切片 - sliceFiles := make([]string, len(files)) + fileSliceList := []int{} for _, file := range files { index, err := strconv.Atoi(file.Name()) if err != nil { - return fmt.Errorf("无效的切片文件名: %v", err) + return nil, fmt.Errorf("无效的切片文件名: %v", err) } - sliceFiles[index-1] = filepath.Join(slicesPath, file.Name()) + fileSliceList = append(fileSliceList, index) + } + sort.Ints(fileSliceList) + return fileSliceList, nil +} + +// MergeFileSlices 合并文件切片 +func MergeFileSlices(fileId string, uploadPath string) (string, error) { + mergeFilePath := filepath.Join(uploadPath, fileId) + slicesPath := filepath.Join(uploadPath, fmt.Sprintf("%s_%s", fileId, "tmp")) + defer os.RemoveAll(slicesPath) + // 创建最终文件 + destFile, err := os.Create(mergeFilePath) + if err != nil { + return "", fmt.Errorf("创建合并文件失败: %v", err) + } + defer destFile.Close() + + fileSliceList, err := GetFileSliceList(fileId, uploadPath) + if err != nil { + return "", err } // 合并文件 buffer := make([]byte, 4*1024*1024) // 4MB buffer - for _, sliceFile := range sliceFiles { - sf, err := os.Open(sliceFile) + for _, index := range fileSliceList { + sliceFilePath := filepath.Join(slicesPath, fmt.Sprintf("%d", index)) + sf, err := os.Open(sliceFilePath) if err != nil { - return fmt.Errorf("打开切片文件失败: %v", err) + return "", fmt.Errorf("打开切片文件失败: %v", err) } for { @@ -67,17 +79,15 @@ func MergeFileSlices(slicesPath string, mergeFilePath string) error { } if err != nil { sf.Close() - return fmt.Errorf("读取切片文件失败: %v", err) + return "", fmt.Errorf("读取切片文件失败: %v", err) } if _, err := destFile.Write(buffer[:n]); err != nil { sf.Close() - return fmt.Errorf("写入合并文件失败: %v", err) + return "", fmt.Errorf("写入合并文件失败: %v", err) } } - sf.Close() } - defer os.RemoveAll(slicesPath) - return nil + return mergeFilePath, nil }