diff --git a/ext/instagram/main.go b/ext/instagram/main.go index e911fe3..3c2a7fd 100644 --- a/ext/instagram/main.go +++ b/ext/instagram/main.go @@ -73,7 +73,7 @@ var ShareURLExtractor = &models.Extractor{ IsRedirect: true, Run: func(ctx *models.DownloadContext) (*models.ExtractorResponse, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) req, err := http.NewRequest( http.MethodGet, ctx.MatchedContentURL, @@ -98,7 +98,7 @@ func MediaListFromAPI( ctx *models.DownloadContext, stories bool, ) ([]*models.Media, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) var mediaList []*models.Media postURL := ctx.MatchedContentURL diff --git a/ext/pinterest/main.go b/ext/pinterest/main.go index b7d6612..e496632 100644 --- a/ext/pinterest/main.go +++ b/ext/pinterest/main.go @@ -86,7 +86,7 @@ var Extractor = &models.Extractor{ func ExtractPinMedia(ctx *models.DownloadContext) ([]*models.Media, error) { pinID := ctx.MatchedContentID contentURL := ctx.MatchedContentURL - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) pinData, err := GetPinData(client, pinID) if err != nil { diff --git a/ext/reddit/main.go b/ext/reddit/main.go index 938de6f..a3de250 100644 --- a/ext/reddit/main.go +++ b/ext/reddit/main.go @@ -31,7 +31,7 @@ var ShortExtractor = &models.Extractor{ IsRedirect: true, Run: func(ctx *models.DownloadContext) (*models.ExtractorResponse, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) req, err := http.NewRequest(http.MethodGet, ctx.MatchedContentURL, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -80,7 +80,7 @@ var Extractor = &models.Extractor{ } func MediaListFromAPI(ctx *models.DownloadContext) ([]*models.Media, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) host := ctx.MatchedGroups["host"] slug := ctx.MatchedGroups["slug"] diff --git a/ext/redgifs/main.go b/ext/redgifs/main.go index 30b3e17..3af2aeb 100644 --- a/ext/redgifs/main.go +++ b/ext/redgifs/main.go @@ -49,7 +49,7 @@ var Extractor = &models.Extractor{ func MediaListFromAPI(ctx *models.DownloadContext) ([]*models.Media, error) { var mediaList []*models.Media - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) response, err := GetVideo( client, ctx.MatchedContentID) diff --git a/ext/tiktok/main.go b/ext/tiktok/main.go index 0d802d4..b934f56 100644 --- a/ext/tiktok/main.go +++ b/ext/tiktok/main.go @@ -76,7 +76,7 @@ var Extractor = &models.Extractor{ } func MediaListFromAPI(ctx *models.DownloadContext) ([]*models.Media, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) var mediaList []*models.Media details, err := GetVideoAPI( diff --git a/ext/twitter/main.go b/ext/twitter/main.go index 2dceefd..306d9bc 100644 --- a/ext/twitter/main.go +++ b/ext/twitter/main.go @@ -28,7 +28,7 @@ var ShortExtractor = &models.Extractor{ IsRedirect: true, Run: func(ctx *models.DownloadContext) (*models.ExtractorResponse, error) { - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) req, err := http.NewRequest(http.MethodGet, ctx.MatchedContentURL, nil) if err != nil { return nil, fmt.Errorf("failed to create req: %w", err) @@ -79,7 +79,7 @@ var Extractor = &models.Extractor{ func MediaListFromAPI(ctx *models.DownloadContext) ([]*models.Media, error) { var mediaList []*models.Media - client := util.GetHTTPSession(ctx.Extractor.CodeName) + client := util.GetHTTPClient(ctx.Extractor.CodeName) tweetData, err := GetTweetAPI( client, ctx.MatchedContentID) diff --git a/util/download.go b/util/download.go index 091cf93..aa9e620 100644 --- a/util/download.go +++ b/util/download.go @@ -21,7 +21,7 @@ import ( "github.com/google/uuid" ) -var downloadHTTPSession = GetDefaultHTTPSession() +var downloadHTTPSession = GetDefaultHTTPClient() func DefaultConfig() *models.DownloadConfig { downloadsDir := os.Getenv("DOWNLOADS_DIR") diff --git a/util/http.go b/util/http.go index e15fd57..a1d8b2f 100644 --- a/util/http.go +++ b/util/http.go @@ -16,30 +16,23 @@ import ( "time" ) -type EdgeProxyClient struct { - *http.Client - - proxyURL string -} - var ( - httpSession *http.Client - httpSessionOnce sync.Once - - extractorsHttpSession = make(map[string]models.HTTPClient) + defaultClient *http.Client + defaultClientOnce sync.Once + extractorClients = make(map[string]models.HTTPClient) ) -func GetDefaultHTTPSession() *http.Client { - httpSessionOnce.Do(func() { - httpSession = &http.Client{ - Transport: GetBaseTransport(), +func GetDefaultHTTPClient() *http.Client { + defaultClientOnce.Do(func() { + defaultClient = &http.Client{ + Transport: createBaseTransport(), Timeout: 60 * time.Second, } }) - return httpSession + return defaultClient } -func GetBaseTransport() *http.Transport { +func createBaseTransport() *http.Transport { return &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -58,112 +51,137 @@ func GetBaseTransport() *http.Transport { } } -func GetHTTPSession(extractor string) models.HTTPClient { - if client, ok := extractorsHttpSession[extractor]; ok { +func GetHTTPClient(extractor string) models.HTTPClient { + if client, exists := extractorClients[extractor]; exists { return client } cfg := config.GetExtractorConfig(extractor) if cfg == nil { - return GetDefaultHTTPSession() + return GetDefaultHTTPClient() } + var client models.HTTPClient + if cfg.EdgeProxyURL != "" { - client := GetEdgeProxyClient(cfg.EdgeProxyURL) - extractorsHttpSession[extractor] = client - return client + client = NewEdgeProxyClient(cfg.EdgeProxyURL) + } else { + client = createClientWithProxy(cfg) } - transport := GetBaseTransport() - client := &http.Client{ + extractorClients[extractor] = client + return client +} + +func createClientWithProxy(cfg *models.ExtractorConfig) *http.Client { + transport := createBaseTransport() + + if cfg.HTTPProxy != "" || cfg.HTTPSProxy != "" { + configureProxyTransport(transport, cfg) + } + + return &http.Client{ Transport: transport, Timeout: 60 * time.Second, } +} - if cfg.HTTPProxy == "" && cfg.HTTPSProxy == "" { - extractorsHttpSession[extractor] = client - return client - } - +func configureProxyTransport( + transport *http.Transport, + cfg *models.ExtractorConfig, +) { var httpProxyURL, httpsProxyURL *url.URL var err error if cfg.HTTPProxy != "" { - if httpProxyURL, err = url.Parse(cfg.HTTPProxy); err != nil { + httpProxyURL, err = url.Parse(cfg.HTTPProxy) + if err != nil { log.Printf("warning: invalid HTTP proxy URL '%s': %v\n", cfg.HTTPProxy, err) } } if cfg.HTTPSProxy != "" { - if httpsProxyURL, err = url.Parse(cfg.HTTPSProxy); err != nil { + httpsProxyURL, err = url.Parse(cfg.HTTPSProxy) + if err != nil { log.Printf("warning: invalid HTTPS proxy URL '%s': %v\n", cfg.HTTPSProxy, err) } } - if httpProxyURL != nil || httpsProxyURL != nil { - noProxyList := strings.Split(cfg.NoProxy, ",") - for i := range noProxyList { - noProxyList[i] = strings.TrimSpace(noProxyList[i]) - } - - transport.Proxy = func(req *http.Request) (*url.URL, error) { - if cfg.NoProxy != "" { - host := req.URL.Hostname() - for _, p := range noProxyList { - if p == "" { - continue - } - if p == host || (strings.HasPrefix(p, ".") && strings.HasSuffix(host, p)) { - return nil, nil - } - } - } - if req.URL.Scheme == "https" && httpsProxyURL != nil { - return httpsProxyURL, nil - } - if req.URL.Scheme == "http" && httpProxyURL != nil { - return httpProxyURL, nil - } - if httpsProxyURL != nil { - return httpsProxyURL, nil - } - return httpProxyURL, nil - } + if httpProxyURL == nil && httpsProxyURL == nil { + return } - extractorsHttpSession[extractor] = client - return client + noProxyList := parseNoProxyList(cfg.NoProxy) + + transport.Proxy = func(req *http.Request) (*url.URL, error) { + if shouldBypassProxy(req.URL.Hostname(), noProxyList) { + return nil, nil + } + + if req.URL.Scheme == "https" && httpsProxyURL != nil { + return httpsProxyURL, nil + } + if req.URL.Scheme == "http" && httpProxyURL != nil { + return httpProxyURL, nil + } + if httpsProxyURL != nil { + return httpsProxyURL, nil + } + return httpProxyURL, nil + } } -func GetEdgeProxyClient(proxyURL string) *EdgeProxyClient { - edgeProxyClient := &EdgeProxyClient{ - Client: &http.Client{ - Transport: GetBaseTransport(), +func parseNoProxyList(noProxy string) []string { + if noProxy == "" { + return nil + } + + list := strings.Split(noProxy, ",") + for i := range list { + list[i] = strings.TrimSpace(list[i]) + } + return list +} + +func shouldBypassProxy(host string, noProxyList []string) bool { + for _, p := range noProxyList { + if p == "" { + continue + } + if p == host || (strings.HasPrefix(p, ".") && strings.HasSuffix(host, p)) { + return true + } + } + return false +} + +type EdgeProxyClient struct { + client *http.Client + proxyURL string +} + +func NewEdgeProxyClient(proxyURL string) *EdgeProxyClient { + return &EdgeProxyClient{ + client: &http.Client{ + Transport: createBaseTransport(), Timeout: 60 * time.Second, }, proxyURL: proxyURL, } - return edgeProxyClient } func (c *EdgeProxyClient) Do(req *http.Request) (*http.Response, error) { if c.proxyURL == "" { return nil, fmt.Errorf("proxy URL is not set") } + targetURL := req.URL.String() encodedURL := url.QueryEscape(targetURL) proxyURLWithParam := c.proxyURL + "?url=" + encodedURL - var bodyBytes []byte - var err error - - if req.Body != nil { - bodyBytes, err = io.ReadAll(req.Body) - if err != nil { - return nil, fmt.Errorf("error reading request body: %w", err) - } - req.Body.Close() - req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + bodyBytes, err := readRequestBody(req) + if err != nil { + return nil, err } proxyReq, err := http.NewRequest( @@ -175,18 +193,42 @@ func (c *EdgeProxyClient) Do(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("error creating proxy request: %w", err) } - for name, values := range req.Header { - for _, value := range values { - proxyReq.Header.Add(name, value) - } - } + copyHeaders(req.Header, proxyReq.Header) - proxyResp, err := c.Client.Do(proxyReq) + proxyResp, err := c.client.Do(proxyReq) if err != nil { return nil, fmt.Errorf("proxy request failed: %w", err) } defer proxyResp.Body.Close() + return parseProxyResponse(proxyResp, req) +} + +func readRequestBody(req *http.Request) ([]byte, error) { + if req.Body == nil { + return nil, nil + } + + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("error reading request body: %w", err) + } + + req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + return bodyBytes, nil +} + +func copyHeaders(source, destination http.Header) { + for name, values := range source { + for _, value := range values { + destination.Add(name, value) + } + } +} + +func parseProxyResponse(proxyResp *http.Response, originalReq *http.Request) (*http.Response, error) { body, err := io.ReadAll(proxyResp.Body) if err != nil { return nil, fmt.Errorf("error reading proxy response: %w", err) @@ -202,8 +244,9 @@ func (c *EdgeProxyClient) Do(req *http.Request) (*http.Response, error) { Status: fmt.Sprintf("%d %s", response.StatusCode, http.StatusText(response.StatusCode)), Body: io.NopCloser(bytes.NewBufferString(response.Text)), Header: make(http.Header), - Request: req, + Request: originalReq, } + parsedResponseURL, err := url.Parse(response.URL) if err != nil { return nil, fmt.Errorf("error parsing response URL: %w", err) diff --git a/util/misc.go b/util/misc.go index 5fdbb92..d6b2110 100644 --- a/util/misc.go +++ b/util/misc.go @@ -29,7 +29,7 @@ func GetLocationURL( userAgent = ChromeUA } req.Header.Set("User-Agent", userAgent) - session := GetDefaultHTTPSession() + session := GetDefaultHTTPClient() resp, err := session.Do(req) if err != nil { return "", fmt.Errorf("failed to send request: %w", err)