From 1a27103347da3a62f66973457eb4159eeaeff6f0 Mon Sep 17 00:00:00 2001 From: stefanodvx <69367859+stefanodvx@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:51:06 +0200 Subject: [PATCH] use ctx with cancel to (hopefully) avoid deadlocks --- bot/core/default.go | 4 +++- bot/core/download.go | 47 +++++++++++++++++++++++------------------- bot/core/inline.go | 4 +++- bot/core/main.go | 8 ++----- bot/handlers/inline.go | 8 +++++++ bot/handlers/url.go | 9 +++++++- 6 files changed, 50 insertions(+), 30 deletions(-) diff --git a/bot/core/default.go b/bot/core/default.go index e806c54..04ac38c 100644 --- a/bot/core/default.go +++ b/bot/core/default.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "govd/database" "govd/models" @@ -12,6 +13,7 @@ import ( func HandleDefaultFormatDownload( bot *gotgbot.Bot, ctx *ext.Context, + taskCtx context.Context, dlCtx *models.DownloadContext, ) error { storedMedias, err := database.GetDefaultMedias( @@ -51,7 +53,7 @@ func HandleDefaultFormatDownload( mediaList[i].Format = defaultFormat } - medias, err := DownloadMedias(mediaList, nil) + medias, err := DownloadMedias(taskCtx, mediaList, nil) if err != nil { return fmt.Errorf("failed to download media list: %w", err) } diff --git a/bot/core/download.go b/bot/core/download.go index f84d749..217670d 100644 --- a/bot/core/download.go +++ b/bot/core/download.go @@ -86,17 +86,16 @@ func downloadMediaItem( } func StartDownloadTask( + ctx context.Context, media *models.Media, idx int, config *models.DownloadConfig, ) (*models.DownloadedMedia, error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return downloadMediaItem(ctx, media, config, idx) } func StartConcurrentDownload( + ctx context.Context, media *models.Media, resultsChan chan<- models.DownloadedMedia, config *models.DownloadConfig, @@ -106,9 +105,6 @@ func StartConcurrentDownload( ) { defer wg.Done() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - result, err := downloadMediaItem(ctx, media, config, idx) if err != nil { errChan <- err @@ -119,13 +115,15 @@ func StartConcurrentDownload( } func DownloadMedia( + ctx context.Context, media *models.Media, config *models.DownloadConfig, ) (*models.DownloadedMedia, error) { - return StartDownloadTask(media, 0, config) + return StartDownloadTask(ctx, media, 0, config) } func DownloadMedias( + ctx context.Context, medias []*models.Media, config *models.DownloadConfig, ) ([]*models.DownloadedMedia, error) { @@ -134,7 +132,7 @@ func DownloadMedias( } if len(medias) == 1 { - result, err := DownloadMedia(medias[0], config) + result, err := DownloadMedia(ctx, medias[0], config) if err != nil { return nil, err } @@ -147,7 +145,7 @@ func DownloadMedias( for idx, media := range medias { wg.Add(1) - go StartConcurrentDownload(media, resultsChan, config, errChan, &wg, idx) + go StartConcurrentDownload(ctx, media, resultsChan, config, errChan, &wg, idx) } go func() { @@ -158,19 +156,26 @@ func DownloadMedias( var results []*models.DownloadedMedia var firstError error - - select { - case err := <-errChan: - if err != nil { - firstError = err + received := 0 + for received < len(medias) { + select { + case result, ok := <-resultsChan: + if ok { + resultCopy := result + results = append(results, &resultCopy) + received++ + } + case err, ok := <-errChan: + if ok && firstError == nil { + firstError = err + received++ + } + case <-ctx.Done(): + if firstError == nil { + firstError = ctx.Err() + } + received++ } - default: - // no errors (yet) - } - - for result := range resultsChan { - resultCopy := result // create a copy to avoid pointer issues - results = append(results, &resultCopy) } if firstError != nil { diff --git a/bot/core/inline.go b/bot/core/inline.go index 6239dd3..6f090f7 100644 --- a/bot/core/inline.go +++ b/bot/core/inline.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "log" "sync" @@ -205,6 +206,7 @@ func StartInlineTask( } func GetInlineFormat( + taskCtx context.Context, bot *gotgbot.Bot, ctx *ext.Context, dlCtx *models.DownloadContext, @@ -239,7 +241,7 @@ func GetInlineFormat( mediaList[i].Format = defaultFormat } messageCaption := FormatCaption(mediaList[0], true) - medias, err := DownloadMedias(mediaList, nil) + medias, err := DownloadMedias(taskCtx, mediaList, nil) if err != nil { errChan <- fmt.Errorf("failed to download medias: %w", err) return diff --git a/bot/core/main.go b/bot/core/main.go index a844a56..0138f22 100644 --- a/bot/core/main.go +++ b/bot/core/main.go @@ -20,17 +20,13 @@ import ( func HandleDownloadRequest( bot *gotgbot.Bot, ctx *ext.Context, + taskCtx context.Context, dlCtx *models.DownloadContext, ) error { - taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - dlCtx.Context = taskCtx - chatID := ctx.EffectiveMessage.Chat.Id if dlCtx.Extractor.Type == enums.ExtractorTypeSingle { TypingEffect(bot, ctx, chatID) - err := HandleDefaultFormatDownload(bot, ctx, dlCtx) + err := HandleDefaultFormatDownload(bot, ctx, taskCtx, dlCtx) if err != nil { return err } diff --git a/bot/handlers/inline.go b/bot/handlers/inline.go index cfc6d1d..1b28cf5 100644 --- a/bot/handlers/inline.go +++ b/bot/handlers/inline.go @@ -34,6 +34,7 @@ func InlineDownloadHandler( }) return nil } + return core.HandleInline(bot, ctx, dlCtx) } @@ -56,7 +57,14 @@ func InlineDownloadResultHandler( ) defer cancel() + taskCtx, cancel := context.WithTimeout( + context.Background(), + 5*time.Minute, + ) + defer cancel() + go core.GetInlineFormat( + taskCtx, bot, ctx, dlCtx, mediaChan, errChan, ) diff --git a/bot/handlers/url.go b/bot/handlers/url.go index 775fb0b..9ff34c3 100644 --- a/bot/handlers/url.go +++ b/bot/handlers/url.go @@ -1,9 +1,11 @@ package handlers import ( + "context" "govd/bot/core" "govd/database" extractors "govd/ext" + "time" "github.com/PaulSonOfLars/gotgbot/v2" "github.com/PaulSonOfLars/gotgbot/v2/ext" @@ -39,7 +41,12 @@ func URLHandler(bot *gotgbot.Bot, ctx *ext.Context) error { return err } } - err = core.HandleDownloadRequest(bot, ctx, dlCtx) + + taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + err = core.HandleDownloadRequest( + bot, ctx, taskCtx, dlCtx) if err != nil { core.HandleErrorMessage( bot, ctx, err)