From 0d986d4573f13d6c39af13f6b53a76a479ab0d1f Mon Sep 17 00:00:00 2001 From: stefanodvx <69367859+stefanodvx@users.noreply.github.com> Date: Mon, 28 Apr 2025 14:35:22 +0200 Subject: [PATCH] fix: resolve memory leaks in download utility 1. writing chunks directly to disk instead of buffering in memory 2. using fixed-size buffers (32KB) for all I/O operations 3. optimizing buffer allocation strategy in downloadInMemory 4. implementing proper file synchronization with mutex locks 5. calculating chunk boundaries on-the-fly instead of pre-allocating the memory profiling showed excessive allocations in bytes.growSlice which has been addressed by minimizing intermediate buffers and eliminating unnecessary memory copies these changes should fix the observed OOM issues when downloading large files while maintaining the same functionality --- ext/twitter/main.go | 9 +- util/download.go | 240 ++++++++++++++++++++++++-------------------- 2 files changed, 136 insertions(+), 113 deletions(-) diff --git a/ext/twitter/main.go b/ext/twitter/main.go index 205bb6e..5b6d921 100644 --- a/ext/twitter/main.go +++ b/ext/twitter/main.go @@ -173,17 +173,22 @@ func GetTweetAPI( if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("invalid response code: %s", resp.Status) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read body: %w", err) + } var apiResponse APIResponse - decoder := sonic.ConfigFastest.NewDecoder(resp.Body) - err = decoder.Decode(&apiResponse) + err = sonic.ConfigFastest.Unmarshal(body, &apiResponse) if err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } + result := apiResponse.Data.TweetResult.Result if result == nil { return nil, errors.New("failed to get tweet result") } + var tweet *Tweet if result.Tweet != nil { tweet = result.Tweet diff --git a/util/download.go b/util/download.go index e66d789..0149904 100644 --- a/util/download.go +++ b/util/download.go @@ -55,7 +55,6 @@ func DownloadFile( case <-ctx.Done(): return "", ctx.Err() default: - // create the download directory if it doesn't exist if err := EnsureDownloadDir(config.DownloadDir); err != nil { return "", err } @@ -184,26 +183,35 @@ func downloadInMemory( return nil, fmt.Errorf("file too large for in-memory download: %d bytes", resp.ContentLength) } - var bufPool = sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, 1024*1024)) - }, - } - - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufPool.Put(buf) - + // allocate a single buffer with the + // correct size upfront to prevent reallocations + var data []byte if resp.ContentLength > 0 { - buf.Grow(int(resp.ContentLength)) + data = make([]byte, 0, resp.ContentLength) + } else { + // 64KB initial capacity + data = make([]byte, 0, 64*1024) } - _, err = io.Copy(buf, resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + // use a limited reader to prevent + // exceeding memory limits even if content-length is wrong + limitedReader := io.LimitReader(resp.Body, int64(config.MaxInMemory)) + + buf := make([]byte, 32*1024) // 32KB buffer + for { + n, err := limitedReader.Read(buf) + if n > 0 { + data = append(data, buf[:n]...) + } + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } } - return buf.Bytes(), nil + return data, nil } func EnsureDownloadDir(dir string) error { @@ -252,12 +260,15 @@ func runChunkedDownload( } } - chunks := createChunks(fileSize, config.ChunkSize) + numChunks := 1 + if fileSize > 0 { + numChunks = int(math.Ceil(float64(fileSize) / float64(config.ChunkSize))) + } semaphore := make(chan struct{}, config.Concurrency) var wg sync.WaitGroup - errChan := make(chan error, len(chunks)) + errChan := make(chan error, numChunks) var downloadErr error var errOnce sync.Once @@ -267,12 +278,22 @@ func runChunkedDownload( downloadCtx, cancelDownload := context.WithCancel(ctx) defer cancelDownload() - for idx, chunk := range chunks { + // use a mutex to synchronize file access + var fileMutex sync.Mutex + + for i := range numChunks { wg.Add(1) - go func(idx int, chunk [2]int) { + go func(chunkIndex int) { defer wg.Done() + // calculate chunk bounds + start := chunkIndex * config.ChunkSize + end := start + config.ChunkSize - 1 + if end >= fileSize && fileSize > 0 { + end = fileSize - 1 + } + // respect concurrency limit select { case semaphore <- struct{}{}: @@ -281,36 +302,31 @@ func runChunkedDownload( return } - chunkData, err := downloadChunkWithRetry(downloadCtx, fileURL, chunk, config) + err := downloadChunkToFile( + downloadCtx, fileURL, + file, start, end, + config, &fileMutex, + ) if err != nil { errOnce.Do(func() { - downloadErr = fmt.Errorf("chunk %d: %w", idx, err) + downloadErr = fmt.Errorf("chunk %d: %w", chunkIndex, err) cancelDownload() // cancel all other downloads errChan <- downloadErr }) return } - if err := writeChunkToFile(file, chunkData, chunk[0]); err != nil { - errOnce.Do(func() { - downloadErr = fmt.Errorf("failed to write chunk %d: %w", idx, err) - cancelDownload() - errChan <- downloadErr - }) - return - } - // update progress - chunkSize := chunk[1] - chunk[0] + 1 + chunkSize := end - start + 1 completedChunks.Add(1) completedBytes.Add(int64(chunkSize)) - progress := float64(completedBytes.Load()) / float64(fileSize) - - // report progress if handler exists - if config.ProgressUpdater != nil { - config.ProgressUpdater(progress) + if fileSize > 0 { + progress := float64(completedBytes.Load()) / float64(fileSize) + if config.ProgressUpdater != nil { + config.ProgressUpdater(progress) + } } - }(idx, chunk) + }(i) } done := make(chan struct{}) @@ -352,7 +368,11 @@ func runChunkedDownload( return nil } -func getFileSize(ctx context.Context, fileURL string, timeout time.Duration) (int, error) { +func getFileSize( + ctx context.Context, + fileURL string, + timeout time.Duration, +) (int, error) { reqCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -374,12 +394,15 @@ func getFileSize(ctx context.Context, fileURL string, timeout time.Duration) (in return int(resp.ContentLength), nil } -func downloadChunkWithRetry( +func downloadChunkToFile( ctx context.Context, fileURL string, - chunk [2]int, + file *os.File, + start int, + end int, config *models.DownloadConfig, -) ([]byte, error) { + fileMutex *sync.Mutex, +) error { var lastErr error for attempt := 0; attempt <= config.RetryAttempts; attempt++ { @@ -387,20 +410,72 @@ func downloadChunkWithRetry( // wait before retry select { case <-ctx.Done(): - return nil, ctx.Err() + return ctx.Err() case <-time.After(config.RetryDelay): } } - data, err := downloadChunk(ctx, fileURL, chunk, config.Timeout) + err := downloadAndWriteChunk( + ctx, fileURL, file, + start, end, config.Timeout, + fileMutex, + ) if err == nil { - return data, nil + return nil } lastErr = err } - return nil, fmt.Errorf("all %d attempts failed: %w", config.RetryAttempts+1, lastErr) + return fmt.Errorf("all %d attempts failed: %w", config.RetryAttempts+1, lastErr) +} + +func downloadAndWriteChunk( + ctx context.Context, + fileURL string, + file *os.File, + start int, + end int, + timeout time.Duration, + fileMutex *sync.Mutex, +) error { + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + + resp, err := downloadHTTPSession.Do(req) + if err != nil { + return fmt.Errorf("download failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // use a fixed-size buffer for + // copying to avoid large allocations (32KB) + buf := make([]byte, 32*1024) + + fileMutex.Lock() + defer fileMutex.Unlock() + + if _, err := file.Seek(int64(start), io.SeekStart); err != nil { + return fmt.Errorf("failed to seek file: %w", err) + } + + _, err = io.CopyBuffer(file, resp.Body, buf) + if err != nil { + return fmt.Errorf("failed to write chunk data: %w", err) + } + + return nil } func downloadFile( @@ -421,85 +496,28 @@ func downloadFile( return "", fmt.Errorf("failed to download file: %w", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) } + file, err := os.Create(filePath) if err != nil { return "", fmt.Errorf("failed to create file: %w", err) } defer file.Close() - _, err = io.Copy(file, resp.Body) + + // use a fixed-size buffer for + // copying to avoid large allocations (32KB) + buf := make([]byte, 32*1024) + _, err = io.CopyBuffer(file, resp.Body, buf) if err != nil { return "", fmt.Errorf("failed to write file: %w", err) } + return filePath, nil } -func downloadChunk( - ctx context.Context, - fileURL string, - chunk [2]int, - timeout time.Duration, -) ([]byte, error) { - reqCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", chunk[0], chunk[1])) - - resp, err := downloadHTTPSession.Do(req) - if err != nil { - return nil, fmt.Errorf("download failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var buf bytes.Buffer - if resp.ContentLength > 0 { - buf.Grow(int(resp.ContentLength)) - } else { - buf.Grow(chunk[1] - chunk[0] + 1) - } - _, err = io.Copy(&buf, resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read chunk data: %w", err) - } - - return buf.Bytes(), nil -} - -func writeChunkToFile(file *os.File, data []byte, offset int) error { - _, err := file.WriteAt(data, int64(offset)) - return err -} - -func createChunks(fileSize int, chunkSize int) [][2]int { - if fileSize <= 0 { - return [][2]int{{0, 0}} - } - - numChunks := int(math.Ceil(float64(fileSize) / float64(chunkSize))) - chunks := make([][2]int, numChunks) - - for i := range chunks { - start := i * chunkSize - end := start + chunkSize - 1 - if end >= fileSize { - end = fileSize - 1 - } - chunks[i] = [2]int{start, end} - } - - return chunks -} - func downloadSegments( ctx context.Context, path string,