cleanup code

This commit is contained in:
stefanodvx 2025-04-15 10:46:54 +02:00
parent 26452d0f53
commit c06c7958e2
10 changed files with 64 additions and 14 deletions

View file

@ -3,6 +3,7 @@ package core
import ( import (
"fmt" "fmt"
"log" "log"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,7 +17,23 @@ import (
"github.com/PaulSonOfLars/gotgbot/v2/ext" "github.com/PaulSonOfLars/gotgbot/v2/ext"
) )
var InlineTasks = make(map[string]*models.DownloadContext) var InlineTasks sync.Map
func GetTask(id string) (*models.DownloadContext, bool) {
value, ok := InlineTasks.Load(id)
if !ok {
return nil, false
}
return value.(*models.DownloadContext), true
}
func SetTask(id string, task *models.DownloadContext) {
InlineTasks.Store(id, task)
}
func DeleteTask(id string) {
InlineTasks.Delete(id)
}
func HandleInline( func HandleInline(
bot *gotgbot.Bot, bot *gotgbot.Bot,
@ -183,7 +200,7 @@ func StartInlineTask(
log.Println("failed to answer inline query") log.Println("failed to answer inline query")
return nil return nil
} }
InlineTasks[taskID] = dlCtx SetTask(taskID, dlCtx)
return nil return nil
} }

View file

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"slices" "slices"
@ -21,6 +22,11 @@ func HandleDownloadRequest(
ctx *ext.Context, ctx *ext.Context,
dlCtx *models.DownloadContext, dlCtx *models.DownloadContext,
) error { ) error {
taskCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
dlCtx.Context = taskCtx
chatID := ctx.EffectiveMessage.Chat.Id chatID := ctx.EffectiveMessage.Chat.Id
if dlCtx.Extractor.Type == enums.ExtractorTypeSingle { if dlCtx.Extractor.Type == enums.ExtractorTypeSingle {
TypingEffect(bot, ctx, chatID) TypingEffect(bot, ctx, chatID)

View file

@ -198,6 +198,16 @@ func HandleErrorMessage(
err error, err error,
) { ) {
currentError := err currentError := err
if errors.As(currentError, context.Canceled) ||
errors.As(currentError, context.DeadlineExceeded) {
SendErrorMessage(
bot, ctx,
"download request canceled or timed out",
)
return
}
for currentError != nil { for currentError != nil {
var botError *util.Error var botError *util.Error
if errors.As(currentError, &botError) { if errors.As(currentError, &botError) {

View file

@ -9,7 +9,9 @@ var helpMessage = "usage:\n" +
"- you can add the bot to a group " + "- you can add the bot to a group " +
"to start catching sent links\n" + "to start catching sent links\n" +
"- you can send a link to the bot privately " + "- you can send a link to the bot privately " +
"to download the media too\n\n" + "to download the media too\n" +
"- you can use inline mode " +
"to download media from any chat\n\n" +
"group commands:\n" + "group commands:\n" +
"- /settings = show current settings\n" + "- /settings = show current settings\n" +
"- /captions (true|false) = enable/disable descriptions\n" + "- /captions (true|false) = enable/disable descriptions\n" +

View file

@ -41,11 +41,12 @@ func InlineDownloadResultHandler(
bot *gotgbot.Bot, bot *gotgbot.Bot,
ctx *ext.Context, ctx *ext.Context,
) error { ) error {
dlCtx, ok := core.InlineTasks[ctx.ChosenInlineResult.ResultId] taskID := ctx.ChosenInlineResult.ResultId
dlCtx, ok := core.GetTask(taskID)
if !ok { if !ok {
return nil return nil
} }
defer delete(core.InlineTasks, ctx.ChosenInlineResult.ResultId) defer core.DeleteTask(taskID)
mediaChan := make(chan *models.Media, 1) mediaChan := make(chan *models.Media, 1)
errChan := make(chan error, 1) errChan := make(chan error, 1)

View file

@ -1,6 +1,9 @@
package models package models
import "context"
type DownloadContext struct { type DownloadContext struct {
Context context.Context
MatchedContentID string MatchedContentID string
MatchedContentURL string MatchedContentURL string
MatchedGroups map[string]string MatchedGroups map[string]string

View file

@ -11,4 +11,5 @@ type DownloadConfig struct {
RetryDelay time.Duration // delay between retries RetryDelay time.Duration // delay between retries
Remux bool // whether to remux the downloaded file with ffmpeg Remux bool // whether to remux the downloaded file with ffmpeg
ProgressUpdater func(float64) // optional function to report download progress ProgressUpdater func(float64) // optional function to report download progress
MaxInMemory int // maximum file size for in-memory downloads
} }

View file

@ -29,6 +29,7 @@ func DefaultConfig() *models.DownloadConfig {
RetryAttempts: 3, RetryAttempts: 3,
RetryDelay: 2 * time.Second, RetryDelay: 2 * time.Second,
Remux: true, Remux: true,
MaxInMemory: 50 * 1024 * 1024, // 50MB
} }
} }
@ -127,7 +128,10 @@ func DownloadFileInMemory(
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
default: default:
data, err := downloadInMemory(ctx, fileURL, config.Timeout) data, err := downloadInMemory(
ctx, fileURL,
config,
)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
continue continue
@ -142,9 +146,12 @@ func DownloadFileInMemory(
func downloadInMemory( func downloadInMemory(
ctx context.Context, ctx context.Context,
fileURL string, fileURL string,
timeout time.Duration, config *models.DownloadConfig,
) ([]byte, error) { ) ([]byte, error) {
reqCtx, cancel := context.WithTimeout(ctx, timeout) reqCtx, cancel := context.WithTimeout(
ctx,
config.Timeout,
)
defer cancel() defer cancel()
select { select {
@ -169,6 +176,10 @@ func downloadInMemory(
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
} }
if resp.ContentLength > int64(config.MaxInMemory) {
return nil, fmt.Errorf("file too large for in-memory download: %d bytes", resp.ContentLength)
}
var buf bytes.Buffer var buf bytes.Buffer
if resp.ContentLength > 0 { if resp.ContentLength > 0 {
buf.Grow(int(resp.ContentLength)) buf.Grow(int(resp.ContentLength))

View file

@ -25,13 +25,15 @@ func GetHTTPSession() *http.Client {
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: 10, MaxIdleConnsPerHost: 20,
MaxConnsPerHost: 10, MaxConnsPerHost: 20,
ResponseHeaderTimeout: 30 * time.Second,
DisableCompression: false,
} }
httpSession = &http.Client{ httpSession = &http.Client{
Transport: transport, Transport: transport,
Timeout: 30 * time.Second, Timeout: 60 * time.Second,
} }
}) })
return httpSession return httpSession

View file

@ -105,9 +105,6 @@ func DetectImageFormat(file io.ReadSeeker) (string, error) {
} }
func isHEIF(header []byte) bool { func isHEIF(header []byte) bool {
if len(header) < 12 {
return false
}
isHeifHeader := header[0] == 0x00 && header[1] == 0x00 && isHeifHeader := header[0] == 0x00 && header[1] == 0x00 &&
header[2] == 0x00 && (header[3] == 0x18 || header[3] == 0x1C) && header[2] == 0x00 && (header[3] == 0x18 || header[3] == 0x1C) &&
bytes.Equal(header[4:8], []byte("ftyp")) bytes.Equal(header[4:8], []byte("ftyp"))