diff --git a/ext/twitter/main.go b/ext/twitter/main.go index 205bb6e..5b6d921 100644 --- a/ext/twitter/main.go +++ b/ext/twitter/main.go @@ -173,17 +173,22 @@ func GetTweetAPI( if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("invalid response code: %s", resp.Status) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read body: %w", err) + } var apiResponse APIResponse - decoder := sonic.ConfigFastest.NewDecoder(resp.Body) - err = decoder.Decode(&apiResponse) + err = sonic.ConfigFastest.Unmarshal(body, &apiResponse) if err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } + result := apiResponse.Data.TweetResult.Result if result == nil { return nil, errors.New("failed to get tweet result") } + var tweet *Tweet if result.Tweet != nil { tweet = result.Tweet diff --git a/util/download.go b/util/download.go index e66d789..0149904 100644 --- a/util/download.go +++ b/util/download.go @@ -55,7 +55,6 @@ func DownloadFile( case <-ctx.Done(): return "", ctx.Err() default: - // create the download directory if it doesn't exist if err := EnsureDownloadDir(config.DownloadDir); err != nil { return "", err } @@ -184,26 +183,35 @@ func downloadInMemory( return nil, fmt.Errorf("file too large for in-memory download: %d bytes", resp.ContentLength) } - var bufPool = sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, 1024*1024)) - }, - } - - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufPool.Put(buf) - + // allocate a single buffer with the + // correct size upfront to prevent reallocations + var data []byte if resp.ContentLength > 0 { - buf.Grow(int(resp.ContentLength)) + data = make([]byte, 0, resp.ContentLength) + } else { + // 64KB initial capacity + data = make([]byte, 0, 64*1024) } - _, err = io.Copy(buf, resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + // use a limited reader to prevent + // exceeding memory limits even if content-length is wrong + limitedReader := io.LimitReader(resp.Body, int64(config.MaxInMemory)) + + buf := make([]byte, 32*1024) // 32KB buffer + for { + n, err := limitedReader.Read(buf) + if n > 0 { + data = append(data, buf[:n]...) + } + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } } - return buf.Bytes(), nil + return data, nil } func EnsureDownloadDir(dir string) error { @@ -252,12 +260,15 @@ func runChunkedDownload( } } - chunks := createChunks(fileSize, config.ChunkSize) + numChunks := 1 + if fileSize > 0 { + numChunks = int(math.Ceil(float64(fileSize) / float64(config.ChunkSize))) + } semaphore := make(chan struct{}, config.Concurrency) var wg sync.WaitGroup - errChan := make(chan error, len(chunks)) + errChan := make(chan error, numChunks) var downloadErr error var errOnce sync.Once @@ -267,12 +278,22 @@ func runChunkedDownload( downloadCtx, cancelDownload := context.WithCancel(ctx) defer cancelDownload() - for idx, chunk := range chunks { + // use a mutex to synchronize file access + var fileMutex sync.Mutex + + for i := range numChunks { wg.Add(1) - go func(idx int, chunk [2]int) { + go func(chunkIndex int) { defer wg.Done() + // calculate chunk bounds + start := chunkIndex * config.ChunkSize + end := start + config.ChunkSize - 1 + if end >= fileSize && fileSize > 0 { + end = fileSize - 1 + } + // respect concurrency limit select { case semaphore <- struct{}{}: @@ -281,36 +302,31 @@ func runChunkedDownload( return } - chunkData, err := downloadChunkWithRetry(downloadCtx, fileURL, chunk, config) + err := downloadChunkToFile( + downloadCtx, fileURL, + file, start, end, + config, &fileMutex, + ) if err != nil { errOnce.Do(func() { - downloadErr = fmt.Errorf("chunk %d: %w", idx, err) + downloadErr = fmt.Errorf("chunk %d: %w", chunkIndex, err) cancelDownload() // cancel all other downloads errChan <- downloadErr }) return } - if err := writeChunkToFile(file, chunkData, chunk[0]); err != nil { - errOnce.Do(func() { - downloadErr = fmt.Errorf("failed to write chunk %d: %w", idx, err) - cancelDownload() - errChan <- downloadErr - }) - return - } - // update progress - chunkSize := chunk[1] - chunk[0] + 1 + chunkSize := end - start + 1 completedChunks.Add(1) completedBytes.Add(int64(chunkSize)) - progress := float64(completedBytes.Load()) / float64(fileSize) - - // report progress if handler exists - if config.ProgressUpdater != nil { - config.ProgressUpdater(progress) + if fileSize > 0 { + progress := float64(completedBytes.Load()) / float64(fileSize) + if config.ProgressUpdater != nil { + config.ProgressUpdater(progress) + } } - }(idx, chunk) + }(i) } done := make(chan struct{}) @@ -352,7 +368,11 @@ func runChunkedDownload( return nil } -func getFileSize(ctx context.Context, fileURL string, timeout time.Duration) (int, error) { +func getFileSize( + ctx context.Context, + fileURL string, + timeout time.Duration, +) (int, error) { reqCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -374,12 +394,15 @@ func getFileSize(ctx context.Context, fileURL string, timeout time.Duration) (in return int(resp.ContentLength), nil } -func downloadChunkWithRetry( +func downloadChunkToFile( ctx context.Context, fileURL string, - chunk [2]int, + file *os.File, + start int, + end int, config *models.DownloadConfig, -) ([]byte, error) { + fileMutex *sync.Mutex, +) error { var lastErr error for attempt := 0; attempt <= config.RetryAttempts; attempt++ { @@ -387,20 +410,72 @@ func downloadChunkWithRetry( // wait before retry select { case <-ctx.Done(): - return nil, ctx.Err() + return ctx.Err() case <-time.After(config.RetryDelay): } } - data, err := downloadChunk(ctx, fileURL, chunk, config.Timeout) + err := downloadAndWriteChunk( + ctx, fileURL, file, + start, end, config.Timeout, + fileMutex, + ) if err == nil { - return data, nil + return nil } lastErr = err } - return nil, fmt.Errorf("all %d attempts failed: %w", config.RetryAttempts+1, lastErr) + return fmt.Errorf("all %d attempts failed: %w", config.RetryAttempts+1, lastErr) +} + +func downloadAndWriteChunk( + ctx context.Context, + fileURL string, + file *os.File, + start int, + end int, + timeout time.Duration, + fileMutex *sync.Mutex, +) error { + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + + resp, err := downloadHTTPSession.Do(req) + if err != nil { + return fmt.Errorf("download failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // use a fixed-size buffer for + // copying to avoid large allocations (32KB) + buf := make([]byte, 32*1024) + + fileMutex.Lock() + defer fileMutex.Unlock() + + if _, err := file.Seek(int64(start), io.SeekStart); err != nil { + return fmt.Errorf("failed to seek file: %w", err) + } + + _, err = io.CopyBuffer(file, resp.Body, buf) + if err != nil { + return fmt.Errorf("failed to write chunk data: %w", err) + } + + return nil } func downloadFile( @@ -421,85 +496,28 @@ func downloadFile( return "", fmt.Errorf("failed to download file: %w", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) } + file, err := os.Create(filePath) if err != nil { return "", fmt.Errorf("failed to create file: %w", err) } defer file.Close() - _, err = io.Copy(file, resp.Body) + + // use a fixed-size buffer for + // copying to avoid large allocations (32KB) + buf := make([]byte, 32*1024) + _, err = io.CopyBuffer(file, resp.Body, buf) if err != nil { return "", fmt.Errorf("failed to write file: %w", err) } + return filePath, nil } -func downloadChunk( - ctx context.Context, - fileURL string, - chunk [2]int, - timeout time.Duration, -) ([]byte, error) { - reqCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, fileURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", chunk[0], chunk[1])) - - resp, err := downloadHTTPSession.Do(req) - if err != nil { - return nil, fmt.Errorf("download failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var buf bytes.Buffer - if resp.ContentLength > 0 { - buf.Grow(int(resp.ContentLength)) - } else { - buf.Grow(chunk[1] - chunk[0] + 1) - } - _, err = io.Copy(&buf, resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read chunk data: %w", err) - } - - return buf.Bytes(), nil -} - -func writeChunkToFile(file *os.File, data []byte, offset int) error { - _, err := file.WriteAt(data, int64(offset)) - return err -} - -func createChunks(fileSize int, chunkSize int) [][2]int { - if fileSize <= 0 { - return [][2]int{{0, 0}} - } - - numChunks := int(math.Ceil(float64(fileSize) / float64(chunkSize))) - chunks := make([][2]int, numChunks) - - for i := range chunks { - start := i * chunkSize - end := start + chunkSize - 1 - if end >= fileSize { - end = fileSize - 1 - } - chunks[i] = [2]int{start, end} - } - - return chunks -} - func downloadSegments( ctx context.Context, path string,