use ctx with cancel to (hopefully) avoid deadlocks

This commit is contained in:
stefanodvx 2025-04-15 17:51:06 +02:00
parent 8d972ff74b
commit 1a27103347
6 changed files with 50 additions and 30 deletions

View file

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"govd/database" "govd/database"
"govd/models" "govd/models"
@ -12,6 +13,7 @@ import (
func HandleDefaultFormatDownload( func HandleDefaultFormatDownload(
bot *gotgbot.Bot, bot *gotgbot.Bot,
ctx *ext.Context, ctx *ext.Context,
taskCtx context.Context,
dlCtx *models.DownloadContext, dlCtx *models.DownloadContext,
) error { ) error {
storedMedias, err := database.GetDefaultMedias( storedMedias, err := database.GetDefaultMedias(
@ -51,7 +53,7 @@ func HandleDefaultFormatDownload(
mediaList[i].Format = defaultFormat mediaList[i].Format = defaultFormat
} }
medias, err := DownloadMedias(mediaList, nil) medias, err := DownloadMedias(taskCtx, mediaList, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to download media list: %w", err) return fmt.Errorf("failed to download media list: %w", err)
} }

View file

@ -86,17 +86,16 @@ func downloadMediaItem(
} }
func StartDownloadTask( func StartDownloadTask(
ctx context.Context,
media *models.Media, media *models.Media,
idx int, idx int,
config *models.DownloadConfig, config *models.DownloadConfig,
) (*models.DownloadedMedia, error) { ) (*models.DownloadedMedia, error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return downloadMediaItem(ctx, media, config, idx) return downloadMediaItem(ctx, media, config, idx)
} }
func StartConcurrentDownload( func StartConcurrentDownload(
ctx context.Context,
media *models.Media, media *models.Media,
resultsChan chan<- models.DownloadedMedia, resultsChan chan<- models.DownloadedMedia,
config *models.DownloadConfig, config *models.DownloadConfig,
@ -106,9 +105,6 @@ func StartConcurrentDownload(
) { ) {
defer wg.Done() defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
result, err := downloadMediaItem(ctx, media, config, idx) result, err := downloadMediaItem(ctx, media, config, idx)
if err != nil { if err != nil {
errChan <- err errChan <- err
@ -119,13 +115,15 @@ func StartConcurrentDownload(
} }
func DownloadMedia( func DownloadMedia(
ctx context.Context,
media *models.Media, media *models.Media,
config *models.DownloadConfig, config *models.DownloadConfig,
) (*models.DownloadedMedia, error) { ) (*models.DownloadedMedia, error) {
return StartDownloadTask(media, 0, config) return StartDownloadTask(ctx, media, 0, config)
} }
func DownloadMedias( func DownloadMedias(
ctx context.Context,
medias []*models.Media, medias []*models.Media,
config *models.DownloadConfig, config *models.DownloadConfig,
) ([]*models.DownloadedMedia, error) { ) ([]*models.DownloadedMedia, error) {
@ -134,7 +132,7 @@ func DownloadMedias(
} }
if len(medias) == 1 { if len(medias) == 1 {
result, err := DownloadMedia(medias[0], config) result, err := DownloadMedia(ctx, medias[0], config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -147,7 +145,7 @@ func DownloadMedias(
for idx, media := range medias { for idx, media := range medias {
wg.Add(1) wg.Add(1)
go StartConcurrentDownload(media, resultsChan, config, errChan, &wg, idx) go StartConcurrentDownload(ctx, media, resultsChan, config, errChan, &wg, idx)
} }
go func() { go func() {
@ -158,19 +156,26 @@ func DownloadMedias(
var results []*models.DownloadedMedia var results []*models.DownloadedMedia
var firstError error var firstError error
received := 0
for received < len(medias) {
select { select {
case err := <-errChan: case result, ok := <-resultsChan:
if err != nil { if ok {
firstError = err resultCopy := result
}
default:
// no errors (yet)
}
for result := range resultsChan {
resultCopy := result // create a copy to avoid pointer issues
results = append(results, &resultCopy) results = append(results, &resultCopy)
received++
}
case err, ok := <-errChan:
if ok && firstError == nil {
firstError = err
received++
}
case <-ctx.Done():
if firstError == nil {
firstError = ctx.Err()
}
received++
}
} }
if firstError != nil { if firstError != nil {

View file

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"sync" "sync"
@ -205,6 +206,7 @@ func StartInlineTask(
} }
func GetInlineFormat( func GetInlineFormat(
taskCtx context.Context,
bot *gotgbot.Bot, bot *gotgbot.Bot,
ctx *ext.Context, ctx *ext.Context,
dlCtx *models.DownloadContext, dlCtx *models.DownloadContext,
@ -239,7 +241,7 @@ func GetInlineFormat(
mediaList[i].Format = defaultFormat mediaList[i].Format = defaultFormat
} }
messageCaption := FormatCaption(mediaList[0], true) messageCaption := FormatCaption(mediaList[0], true)
medias, err := DownloadMedias(mediaList, nil) medias, err := DownloadMedias(taskCtx, mediaList, nil)
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to download medias: %w", err) errChan <- fmt.Errorf("failed to download medias: %w", err)
return return

View file

@ -20,17 +20,13 @@ import (
func HandleDownloadRequest( func HandleDownloadRequest(
bot *gotgbot.Bot, bot *gotgbot.Bot,
ctx *ext.Context, ctx *ext.Context,
taskCtx context.Context,
dlCtx *models.DownloadContext, dlCtx *models.DownloadContext,
) error { ) error {
taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
dlCtx.Context = taskCtx
chatID := ctx.EffectiveMessage.Chat.Id chatID := ctx.EffectiveMessage.Chat.Id
if dlCtx.Extractor.Type == enums.ExtractorTypeSingle { if dlCtx.Extractor.Type == enums.ExtractorTypeSingle {
TypingEffect(bot, ctx, chatID) TypingEffect(bot, ctx, chatID)
err := HandleDefaultFormatDownload(bot, ctx, dlCtx) err := HandleDefaultFormatDownload(bot, ctx, taskCtx, dlCtx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -34,6 +34,7 @@ func InlineDownloadHandler(
}) })
return nil return nil
} }
return core.HandleInline(bot, ctx, dlCtx) return core.HandleInline(bot, ctx, dlCtx)
} }
@ -56,7 +57,14 @@ func InlineDownloadResultHandler(
) )
defer cancel() defer cancel()
taskCtx, cancel := context.WithTimeout(
context.Background(),
5*time.Minute,
)
defer cancel()
go core.GetInlineFormat( go core.GetInlineFormat(
taskCtx,
bot, ctx, dlCtx, bot, ctx, dlCtx,
mediaChan, errChan, mediaChan, errChan,
) )

View file

@ -1,9 +1,11 @@
package handlers package handlers
import ( import (
"context"
"govd/bot/core" "govd/bot/core"
"govd/database" "govd/database"
extractors "govd/ext" extractors "govd/ext"
"time"
"github.com/PaulSonOfLars/gotgbot/v2" "github.com/PaulSonOfLars/gotgbot/v2"
"github.com/PaulSonOfLars/gotgbot/v2/ext" "github.com/PaulSonOfLars/gotgbot/v2/ext"
@ -39,7 +41,12 @@ func URLHandler(bot *gotgbot.Bot, ctx *ext.Context) error {
return err return err
} }
} }
err = core.HandleDownloadRequest(bot, ctx, dlCtx)
taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
err = core.HandleDownloadRequest(
bot, ctx, taskCtx, dlCtx)
if err != nil { if err != nil {
core.HandleErrorMessage( core.HandleErrorMessage(
bot, ctx, err) bot, ctx, err)