govd/util/download.go
stefanodvx 0d986d4573 fix: resolve memory leaks in download utility
1. writing chunks directly to disk instead of buffering in memory
2. using fixed-size buffers (32KB) for all I/O operations
3. optimizing buffer allocation strategy in downloadInMemory
4. implementing proper file synchronization with mutex locks
5. calculating chunk boundaries on the fly instead of pre-allocating them

memory profiling showed excessive allocations in bytes.growSlice,
which have been addressed by minimizing intermediate buffers and
eliminating unnecessary memory copies

these changes should fix the OOM issues observed when downloading
large files while maintaining the same functionality
2025-04-28 14:35:22 +02:00

package util

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"govd/models"
	"govd/util/av"

	"github.com/google/uuid"
)
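
// downloadHTTPSession is the shared HTTP client reused by every
// download helper in this file.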
var downloadHTTPSession = GetDefaultHTTPClient()
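
// DefaultConfig returns a DownloadConfig with the package defaults:
// 10MB chunks, four concurrent workers, a 30-second request timeout,
// three retries with a 2-second delay, remuxing enabled, and a 50MB
// cap for in-memory downloads. The download directory comes from the
// DOWNLOADS_DIR environment variable, falling back to "downloads".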
func DefaultConfig() *models.DownloadConfig {
downloadsDir := os.Getenv("DOWNLOADS_DIR")
if downloadsDir == "" {
downloadsDir = "downloads"
}
return &models.DownloadConfig{
ChunkSize: 10 * 1024 * 1024, // 10MB
Concurrency: 4,
Timeout: 30 * time.Second,
DownloadDir: downloadsDir,
RetryAttempts: 3,
RetryDelay: 2 * time.Second,
Remux: true,
MaxInMemory: 50 * 1024 * 1024, // 50MB
}
}
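
// DownloadFile tries each URL in URLList in order, downloading the
// first one that succeeds into config.DownloadDir under fileName and
// optionally remuxing the result (the partial file is removed if
// remuxing fails). If every URL fails it returns ErrDownloadFailed
// wrapping the per-URL errors. A nil config falls back to
// DefaultConfig.
//
// Illustrative usage (the URL is a placeholder):
//
//	path, err := DownloadFile(ctx, []string{"https://example.com/video.mp4"}, "video.mp4", nil)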
func DownloadFile(
ctx context.Context,
URLList []string,
fileName string,
config *models.DownloadConfig,
) (string, error) {
if config == nil {
config = DefaultConfig()
}
var errs []error
for _, fileURL := range URLList {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
if err := EnsureDownloadDir(config.DownloadDir); err != nil {
return "", err
}
filePath := filepath.Join(config.DownloadDir, fileName)
err := runChunkedDownload(ctx, fileURL, filePath, config)
if err != nil {
errs = append(errs, err)
continue
}
if config.Remux {
err := av.RemuxFile(filePath)
if err != nil {
os.Remove(filePath)
return "", fmt.Errorf("remuxing failed: %w", err)
}
}
return filePath, nil
}
}
return "", fmt.Errorf("%w: %v", ErrDownloadFailed, errs)
}
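
// DownloadFileWithSegments downloads every URL in segmentURLs into a
// temporary directory, merges the segments into fileName with
// av.MergeSegments, and removes the temporary directory afterwards.
// It returns the path of the merged file.
//
// Illustrative usage (segmentURLs is assumed to hold the resolved
// segment URLs):
//
//	path, err := DownloadFileWithSegments(ctx, segmentURLs, "clip.mp4", nil)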
func DownloadFileWithSegments(
ctx context.Context,
segmentURLs []string,
fileName string,
config *models.DownloadConfig,
) (string, error) {
if config == nil {
config = DefaultConfig()
}
if err := EnsureDownloadDir(config.DownloadDir); err != nil {
return "", err
}
tempDir := filepath.Join(
config.DownloadDir,
"segments"+uuid.NewString(),
)
if err := os.MkdirAll(tempDir, 0755); err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}
downloadedFiles, err := downloadSegments(ctx, tempDir, segmentURLs, config)
if err != nil {
os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to download segments: %w", err)
}
mergedFilePath, err := av.MergeSegments(downloadedFiles, fileName)
if err != nil {
os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to merge segments: %w", err)
}
if err := os.RemoveAll(tempDir); err != nil {
return "", fmt.Errorf("failed to remove temporary directory: %w", err)
}
return mergedFilePath, nil
}
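
// DownloadFileInMemory tries each URL in URLList in order and returns
// the body of the first successful download as a bytes.Reader. Bodies
// larger than config.MaxInMemory are rejected rather than buffered.
//
// Illustrative usage (the URL is a placeholder):
//
//	reader, err := DownloadFileInMemory(ctx, []string{"https://example.com/thumb.jpg"}, nil)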
func DownloadFileInMemory(
ctx context.Context,
URLList []string,
config *models.DownloadConfig,
) (*bytes.Reader, error) {
if config == nil {
config = DefaultConfig()
}
var errs []error
for _, fileURL := range URLList {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
data, err := downloadInMemory(
ctx, fileURL,
config,
)
if err != nil {
errs = append(errs, err)
continue
}
return bytes.NewReader(data), nil
}
}
return nil, fmt.Errorf("%w: %v", ErrDownloadFailed, errs)
}
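
// downloadInMemory fetches fileURL into a single byte slice. The slice
// is pre-sized from Content-Length when the server reports one, and
// the body is read through a limited reader so that a wrong or missing
// Content-Length can never buffer more than config.MaxInMemory bytes.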
func downloadInMemory(
ctx context.Context,
fileURL string,
config *models.DownloadConfig,
) ([]byte, error) {
reqCtx, cancel := context.WithTimeout(
ctx,
config.Timeout,
)
defer cancel()
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// continue with the request
}
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := downloadHTTPSession.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
if resp.ContentLength > int64(config.MaxInMemory) {
return nil, fmt.Errorf("file too large for in-memory download: %d bytes", resp.ContentLength)
}
// allocate a single buffer with the
// correct size upfront to prevent reallocations
var data []byte
if resp.ContentLength > 0 {
data = make([]byte, 0, resp.ContentLength)
} else {
// 64KB initial capacity
data = make([]byte, 0, 64*1024)
}
	// read through a limited reader so a wrong or missing
	// content-length can never exceed the in-memory cap;
	// the extra byte lets us distinguish "exactly at the
	// limit" from "truncated at the limit"
	limitedReader := io.LimitReader(resp.Body, int64(config.MaxInMemory)+1)
	buf := make([]byte, 32*1024) // fixed 32KB copy buffer
	for {
		n, err := limitedReader.Read(buf)
		if n > 0 {
			data = append(data, buf[:n]...)
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}
	}
	if int64(len(data)) > int64(config.MaxInMemory) {
		return nil, fmt.Errorf("file too large for in-memory download: exceeds %d bytes", config.MaxInMemory)
	}
	return data, nil
}
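
// EnsureDownloadDir makes sure dir exists, creating it and any
// missing parents on first use.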
func EnsureDownloadDir(dir string) error {
if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create downloads directory: %w", err)
}
} else {
return fmt.Errorf("error accessing directory: %w", err)
}
}
return nil
}
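
// runChunkedDownload downloads fileURL into filePath using up to
// config.Concurrency parallel range requests of config.ChunkSize
// bytes each. The destination file is pre-allocated, chunks are
// written in place under a shared mutex, and the partial file is
// removed when any chunk ultimately fails or the context is
// cancelled.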
func runChunkedDownload(
ctx context.Context,
fileURL string,
filePath string,
config *models.DownloadConfig,
) error {
	// clamp concurrency to the available CPUs, leaving one
	// for the rest of the process; work on a local copy so
	// the caller's config is not mutated
	concurrency := config.Concurrency
	if optimal := max(1, runtime.GOMAXPROCS(0)-1); concurrency > optimal {
		concurrency = optimal
	}
fileSize, err := getFileSize(ctx, fileURL, config.Timeout)
if err != nil {
return err
}
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
// pre-allocate file size if possible
if fileSize > 0 {
if err := file.Truncate(int64(fileSize)); err != nil {
return fmt.Errorf("failed to allocate file space: %w", err)
}
}
numChunks := 1
if fileSize > 0 {
numChunks = int(math.Ceil(float64(fileSize) / float64(config.ChunkSize)))
}
	semaphore := make(chan struct{}, concurrency)
var wg sync.WaitGroup
errChan := make(chan error, numChunks)
var downloadErr error
var errOnce sync.Once
var completedChunks atomic.Int64
var completedBytes atomic.Int64
downloadCtx, cancelDownload := context.WithCancel(ctx)
defer cancelDownload()
// use a mutex to synchronize file access
var fileMutex sync.Mutex
for i := range numChunks {
wg.Add(1)
go func(chunkIndex int) {
defer wg.Done()
// calculate chunk bounds
start := chunkIndex * config.ChunkSize
end := start + config.ChunkSize - 1
if end >= fileSize && fileSize > 0 {
end = fileSize - 1
}
// respect concurrency limit
select {
case semaphore <- struct{}{}:
defer func() { <-semaphore }()
case <-downloadCtx.Done():
return
}
err := downloadChunkToFile(
downloadCtx, fileURL,
file, start, end,
config, &fileMutex,
)
if err != nil {
errOnce.Do(func() {
downloadErr = fmt.Errorf("chunk %d: %w", chunkIndex, err)
cancelDownload() // cancel all other downloads
errChan <- downloadErr
})
return
}
// update progress
chunkSize := end - start + 1
completedChunks.Add(1)
completedBytes.Add(int64(chunkSize))
if fileSize > 0 {
progress := float64(completedBytes.Load()) / float64(fileSize)
if config.ProgressUpdater != nil {
config.ProgressUpdater(progress)
}
}
}(i)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(errChan)
close(done)
}()
var multiErr []error
select {
case err := <-errChan:
if err != nil {
multiErr = append(multiErr, err)
// collect all errors
for e := range errChan {
if e != nil {
multiErr = append(multiErr, e)
}
}
}
<-done
case <-ctx.Done():
cancelDownload()
<-done // wait for all goroutines to finish
os.Remove(filePath)
return ctx.Err()
case <-done:
// no errors
}
if len(multiErr) > 0 {
os.Remove(filePath)
return fmt.Errorf("multiple download errors: %v", multiErr)
}
return nil
}
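
// getFileSize issues a HEAD request for fileURL and returns the
// reported Content-Length, which is negative when the server does
// not advertise a size.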
func getFileSize(
ctx context.Context,
fileURL string,
timeout time.Duration,
) (int, error) {
reqCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodHead, fileURL, nil)
if err != nil {
return 0, fmt.Errorf("failed to create request: %w", err)
}
resp, err := downloadHTTPSession.Do(req)
if err != nil {
return 0, fmt.Errorf("failed to get file size: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("failed to get file info: status code %d", resp.StatusCode)
}
return int(resp.ContentLength), nil
}
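
// downloadChunkToFile downloads the byte range [start, end] of
// fileURL into file, retrying up to config.RetryAttempts times with
// config.RetryDelay between attempts.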
func downloadChunkToFile(
ctx context.Context,
fileURL string,
file *os.File,
start int,
end int,
config *models.DownloadConfig,
fileMutex *sync.Mutex,
) error {
var lastErr error
for attempt := 0; attempt <= config.RetryAttempts; attempt++ {
if attempt > 0 {
// wait before retry
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(config.RetryDelay):
}
}
err := downloadAndWriteChunk(
ctx, fileURL, file,
start, end, config.Timeout,
fileMutex,
)
if err == nil {
return nil
}
lastErr = err
}
return fmt.Errorf("all %d attempts failed: %w", config.RetryAttempts+1, lastErr)
}
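
// downloadAndWriteChunk performs a single ranged GET for bytes
// [start, end] and streams the body into file at offset start. The
// seek-and-copy is guarded by fileMutex because all chunk goroutines
// share one *os.File.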
func downloadAndWriteChunk(
ctx context.Context,
fileURL string,
file *os.File,
start int,
end int,
timeout time.Duration,
fileMutex *sync.Mutex,
) error {
reqCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := downloadHTTPSession.Do(req)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp.Body.Close()
	if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	if resp.StatusCode == http.StatusOK && start > 0 {
		// the server ignored the Range header and is sending the
		// whole body; writing it at offset start would corrupt
		// the output file
		return fmt.Errorf("server ignored range request for offset %d", start)
	}
// use a fixed-size buffer for
// copying to avoid large allocations (32KB)
buf := make([]byte, 32*1024)
fileMutex.Lock()
defer fileMutex.Unlock()
if _, err := file.Seek(int64(start), io.SeekStart); err != nil {
return fmt.Errorf("failed to seek file: %w", err)
}
_, err = io.CopyBuffer(file, resp.Body, buf)
if err != nil {
return fmt.Errorf("failed to write chunk data: %w", err)
}
return nil
}
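
// downloadFile fetches fileURL with a plain (non-ranged) GET and
// streams the body into filePath through a fixed 32KB buffer.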
func downloadFile(
ctx context.Context,
fileURL string,
filePath string,
timeout time.Duration,
) (string, error) {
reqCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
resp, err := downloadHTTPSession.Do(req)
if err != nil {
return "", fmt.Errorf("failed to download file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
file, err := os.Create(filePath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
// use a fixed-size buffer for
// copying to avoid large allocations (32KB)
buf := make([]byte, 32*1024)
_, err = io.CopyBuffer(file, resp.Body, buf)
if err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
}
return filePath, nil
}
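
// downloadSegments fetches every segment URL into a numbered file
// ("segment_00000", ...) under path, bounding parallelism with a
// semaphore. On the first failure it cancels the remaining downloads
// and removes any segments already written.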
func downloadSegments(
ctx context.Context,
path string,
segmentURLs []string,
config *models.DownloadConfig,
) ([]string, error) {
if config == nil {
config = DefaultConfig()
}
semaphore := make(chan struct{}, config.Concurrency)
var wg sync.WaitGroup
var firstErr atomic.Value
downloadedFiles := make([]string, len(segmentURLs))
defer func() {
if firstErr.Load() != nil {
for _, path := range downloadedFiles {
if path != "" {
os.Remove(path)
}
}
}
}()
downloadCtx, cancelDownload := context.WithCancel(ctx)
defer cancelDownload()
for i, segmentURL := range segmentURLs {
wg.Add(1)
go func(idx int, url string) {
defer wg.Done()
select {
case <-downloadCtx.Done():
return
default:
// continue with the download
}
// acquire semaphore slot
semaphore <- struct{}{}
defer func() { <-semaphore }()
segmentFileName := fmt.Sprintf("segment_%05d", idx)
segmentPath := filepath.Join(path, segmentFileName)
			// use downloadCtx (not ctx) so that cancelDownload
			// actually aborts in-flight segment downloads
			filePath, err := downloadFile(
				downloadCtx, url, segmentPath,
				config.Timeout,
			)
if err != nil {
if firstErr.Load() == nil {
firstErr.Store(fmt.Errorf("failed to download segment %d: %w", idx, err))
cancelDownload()
}
return
}
downloadedFiles[idx] = filePath
}(i, segmentURL)
}
wg.Wait()
if err := firstErr.Load(); err != nil {
return nil, err.(error)
}
for i, file := range downloadedFiles {
if file == "" {
return nil, fmt.Errorf("segment %d was not downloaded", i)
}
if _, err := os.Stat(file); os.IsNotExist(err) {
return nil, fmt.Errorf("segment %d file does not exist: %w", i, err)
}
}
return downloadedFiles, nil
}