From c06c7958e25a16eba88febdae72bba7a801d555d Mon Sep 17 00:00:00 2001 From: stefanodvx <69367859+stefanodvx@users.noreply.github.com> Date: Tue, 15 Apr 2025 10:46:54 +0200 Subject: [PATCH] cleanup code --- bot/core/inline.go | 21 +++++++++++++++++++-- bot/core/main.go | 6 ++++++ bot/core/util.go | 10 ++++++++++ bot/handlers/help.go | 4 +++- bot/handlers/inline.go | 5 +++-- models/ctx.go | 3 +++ models/download.go | 1 + util/download.go | 17 ++++++++++++++--- util/http.go | 8 +++++--- util/img.go | 3 --- 10 files changed, 64 insertions(+), 14 deletions(-) diff --git a/bot/core/inline.go b/bot/core/inline.go index b1d052b..6239dd3 100644 --- a/bot/core/inline.go +++ b/bot/core/inline.go @@ -3,6 +3,7 @@ package core import ( "fmt" "log" + "sync" "github.com/google/uuid" "github.com/pkg/errors" @@ -16,7 +17,23 @@ import ( "github.com/PaulSonOfLars/gotgbot/v2/ext" ) -var InlineTasks = make(map[string]*models.DownloadContext) +var InlineTasks sync.Map + +func GetTask(id string) (*models.DownloadContext, bool) { + value, ok := InlineTasks.Load(id) + if !ok { + return nil, false + } + return value.(*models.DownloadContext), true +} + +func SetTask(id string, task *models.DownloadContext) { + InlineTasks.Store(id, task) +} + +func DeleteTask(id string) { + InlineTasks.Delete(id) +} func HandleInline( bot *gotgbot.Bot, @@ -183,7 +200,7 @@ func StartInlineTask( log.Println("failed to answer inline query") return nil } - InlineTasks[taskID] = dlCtx + SetTask(taskID, dlCtx) return nil } diff --git a/bot/core/main.go b/bot/core/main.go index 37c16da..a844a56 100644 --- a/bot/core/main.go +++ b/bot/core/main.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "os" "slices" @@ -21,6 +22,11 @@ func HandleDownloadRequest( ctx *ext.Context, dlCtx *models.DownloadContext, ) error { + taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + dlCtx.Context = taskCtx + chatID := ctx.EffectiveMessage.Chat.Id if dlCtx.Extractor.Type == enums.ExtractorTypeSingle { TypingEffect(bot, ctx, chatID) diff --git a/bot/core/util.go b/bot/core/util.go index 77ad514..b487852 100644 --- a/bot/core/util.go +++ b/bot/core/util.go @@ -198,6 +198,16 @@ func HandleErrorMessage( err error, ) { currentError := err + + if errors.As(currentError, context.Canceled) || + errors.As(currentError, context.DeadlineExceeded) { + SendErrorMessage( + bot, ctx, + "download request canceled or timed out", + ) + return + } + for currentError != nil { var botError *util.Error if errors.As(currentError, &botError) { diff --git a/bot/handlers/help.go b/bot/handlers/help.go index 4e88583..c373f94 100644 --- a/bot/handlers/help.go +++ b/bot/handlers/help.go @@ -9,7 +9,9 @@ var helpMessage = "usage:\n" + "- you can add the bot to a group " + "to start catching sent links\n" + "- you can send a link to the bot privately " + - "to download the media too\n\n" + + "to download the media too\n" + + "- you can use inline mode " + + "to download media from any chat\n\n" + "group commands:\n" + "- /settings = show current settings\n" + "- /captions (true|false) = enable/disable descriptions\n" + diff --git a/bot/handlers/inline.go b/bot/handlers/inline.go index 3140491..cfc6d1d 100644 --- a/bot/handlers/inline.go +++ b/bot/handlers/inline.go @@ -41,11 +41,12 @@ func InlineDownloadResultHandler( bot *gotgbot.Bot, ctx *ext.Context, ) error { - dlCtx, ok := core.InlineTasks[ctx.ChosenInlineResult.ResultId] + taskID := ctx.ChosenInlineResult.ResultId + dlCtx, ok := core.GetTask(taskID) if !ok { return nil } - defer delete(core.InlineTasks, ctx.ChosenInlineResult.ResultId) + defer core.DeleteTask(taskID) mediaChan := make(chan *models.Media, 1) errChan := make(chan error, 1) diff --git a/models/ctx.go b/models/ctx.go index 7b2d164..63227af 100644 --- a/models/ctx.go +++ b/models/ctx.go @@ -1,6 +1,9 @@ package models +import "context" + type DownloadContext struct { + Context context.Context MatchedContentID string MatchedContentURL string MatchedGroups map[string]string diff --git a/models/download.go b/models/download.go index 1c5f772..ca606f2 100644 --- a/models/download.go +++ b/models/download.go @@ -11,4 +11,5 @@ type DownloadConfig struct { RetryDelay time.Duration // delay between retries Remux bool // whether to remux the downloaded file with ffmpeg ProgressUpdater func(float64) // optional function to report download progress + MaxInMemory int // maximum file size for in-memory downloads } diff --git a/util/download.go b/util/download.go index e6e6e54..8fc6d1b 100644 --- a/util/download.go +++ b/util/download.go @@ -29,6 +29,7 @@ func DefaultConfig() *models.DownloadConfig { RetryAttempts: 3, RetryDelay: 2 * time.Second, Remux: true, + MaxInMemory: 50 * 1024 * 1024, // 50MB } } @@ -127,7 +128,10 @@ func DownloadFileInMemory( case <-ctx.Done(): return nil, ctx.Err() default: - data, err := downloadInMemory(ctx, fileURL, config.Timeout) + data, err := downloadInMemory( + ctx, fileURL, + config, + ) if err != nil { errs = append(errs, err) continue @@ -142,9 +146,12 @@ func DownloadFileInMemory( func downloadInMemory( ctx context.Context, fileURL string, - timeout time.Duration, + config *models.DownloadConfig, ) ([]byte, error) { - reqCtx, cancel := context.WithTimeout(ctx, timeout) + reqCtx, cancel := context.WithTimeout( + ctx, + config.Timeout, + ) defer cancel() select { @@ -169,6 +176,10 @@ func downloadInMemory( return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } + if resp.ContentLength > int64(config.MaxInMemory) { + return nil, fmt.Errorf("file too large for in-memory download: %d bytes", resp.ContentLength) + } + var buf bytes.Buffer if resp.ContentLength > 0 { buf.Grow(int(resp.ContentLength)) diff --git a/util/http.go b/util/http.go index 8c21c8c..2cb86d7 100644 --- a/util/http.go +++ b/util/http.go @@ -25,13 +25,15 @@ func GetHTTPSession() *http.Client { IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - MaxIdleConnsPerHost: 10, - MaxConnsPerHost: 10, + MaxIdleConnsPerHost: 20, + MaxConnsPerHost: 20, + ResponseHeaderTimeout: 30 * time.Second, + DisableCompression: false, } httpSession = &http.Client{ Transport: transport, - Timeout: 30 * time.Second, + Timeout: 60 * time.Second, } }) return httpSession diff --git a/util/img.go b/util/img.go index 69db8b1..339c83b 100644 --- a/util/img.go +++ b/util/img.go @@ -105,9 +105,6 @@ func DetectImageFormat(file io.ReadSeeker) (string, error) { } func isHEIF(header []byte) bool { - if len(header) < 12 { - return false - } isHeifHeader := header[0] == 0x00 && header[1] == 0x00 && header[2] == 0x00 && (header[3] == 0x18 || header[3] == 0x1C) && bytes.Equal(header[4:8], []byte("ftyp"))