-
Notifications
You must be signed in to change notification settings - Fork 483
Expand file tree
/
Copy pathrequests.py
More file actions
321 lines (247 loc) · 11.5 KB
/
requests.py
File metadata and controls
321 lines (247 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
from __future__ import annotations
import abc
import logging
import time
import types
# Allow some request objects to be imported from here instead of requests
import warnings
from datetime import datetime, timedelta
from email.message import EmailMessage
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from urllib.request import urlopen
import requests
from loguru import logger
from requests import RequestException
from flexget import __version__ as version
from flexget.utils.tools import TimedDict, parse_timedelta
# If we use just 'requests' here, we'll get the logger created by requests, rather than our own
logger = logger.bind(name='utils.requests')
# Don't emit info level urllib3 log messages or below
logging.getLogger('requests.packages.urllib3').setLevel(logging.WARNING)
# same as above, but for systems where urllib3 isn't part of the requests package (i.e., Ubuntu)
logging.getLogger('urllib3').setLevel(logging.WARNING)
# Time to wait before trying an unresponsive site again
WAIT_TIME = timedelta(seconds=60)
# Remembers sites that have timed out; entries expire automatically after WAIT_TIME
unresponsive_hosts = TimedDict(WAIT_TIME)
if TYPE_CHECKING:
    from collections.abc import Mapping
    from typing import TypedDict

    class StateCacheDict(TypedDict):
        """Shape of the per-domain state stored by `TokenBucketLimiter.state_cache`."""

        # Current token balance of the bucket (fractional while regenerating)
        tokens: float | int
        # When tokens were last regenerated
        last_update: datetime
def is_unresponsive(url: str) -> bool:
    """Check if host of given url has timed out within WAIT_TIME.

    :param url: The url to check
    :return: True if the host has timed out within WAIT_TIME
    """
    # Membership in the TimedDict expires automatically after WAIT_TIME
    return urlparse(url).hostname in unresponsive_hosts
def set_unresponsive(url: str) -> None:
    """Mark the host of a given url as unresponsive.

    :param url: The url that timed out
    """
    host = urlparse(url).hostname
    # If somehow this is called again before the previous timer clears,
    # leave the existing entry alone so its expiry isn't refreshed.
    if host not in unresponsive_hosts:
        unresponsive_hosts[host] = True
class DomainLimiter(abc.ABC):
    """Base class for callables that throttle requests to a single domain."""

    def __init__(self, domain: str) -> None:
        # The domain this limiter applies to; used as the key in Session.domain_limiters
        self.domain = domain

    @abc.abstractmethod
    def __call__(self) -> None:
        """Be called once before every request to the domain."""
class TokenBucketLimiter(DomainLimiter):
    """A token bucket rate limiter for domains.

    New instances for the same domain will restore previous values.
    """

    # This is just an in memory cache right now, it works for the daemon, and across tasks in a single execution
    # but not for multiple executions via cron. Do we need to store this to db?
    state_cache: dict[str, StateCacheDict] = {}

    def __init__(
        self,
        domain: str,
        tokens: float,
        rate: str | timedelta,
        wait: bool = True,
    ) -> None:
        """Init a token bucket rate limiter.

        :param tokens: Size of bucket
        :param rate: Amount of time to accrue 1 token. Either `timedelta` or interval string.
        :param bool wait: If true, will wait for a token to be available. If false, errors when token is not available.
        """
        super().__init__(domain)
        self.max_tokens = tokens
        self.rate = parse_timedelta(rate)
        self.wait = wait
        # Restore previous state for this domain, or establish new state cache
        self.state = self.state_cache.setdefault(
            domain, {'tokens': self.max_tokens, 'last_update': datetime.now()}
        )

    @property
    def tokens(self) -> float | int:
        # Clamp to the bucket size so a restored state can never exceed max_tokens
        return min(self.max_tokens, self.state['tokens'])

    @tokens.setter
    def tokens(self, value: float) -> None:
        self.state['tokens'] = value

    @property
    def last_update(self) -> datetime:
        return self.state['last_update']

    @last_update.setter
    def last_update(self, value: datetime) -> None:
        self.state['last_update'] = value

    def __call__(self) -> None:
        """Consume one token, regenerating first and sleeping (or raising) if the bucket is empty."""
        if self.tokens < self.max_tokens:
            # Accrue tokens proportionally to the time elapsed since last update
            regen = (datetime.now() - self.last_update).total_seconds() / self.rate.total_seconds()
            self.tokens += regen
            self.last_update = datetime.now()
        if self.tokens < 1:
            if not self.wait:
                raise RequestException(f'Requests to {self.domain} have exceeded their limit.')
            # Time until the balance would reach exactly 1 token
            wait = self.rate.total_seconds() * (1 - self.tokens)
            # Don't spam console if wait is low
            level = 'DEBUG' if wait < 4 else 'VERBOSE'
            logger.log(level, 'Waiting {:.2f} seconds until next request to {}', wait, self.domain)
            # Sleep until it is time for the next request
            time.sleep(wait)
        # NOTE(review): when we slept above, the balance isn't re-regenerated before this
        # subtraction, so it can go slightly negative; the next call's regen (which
        # includes the sleep time) recoups it.
        self.tokens -= 1
class TimedLimiter(TokenBucketLimiter):
    """Enforces a minimum interval between requests to a given domain."""

    def __init__(self, domain: str, interval: str | timedelta) -> None:
        # A bucket of size 1 refilling once per `interval` == a fixed gap between requests
        super().__init__(domain, 1, interval)
def _wrap_urlopen(url: str, timeout: int | None = None) -> requests.Response:
    """Handle alternate schemes using urllib, wrap the response in a requests.Response.

    This is not installed as an adapter in requests, since urls without network locations
    (e.g. file:///somewhere) will cause errors

    :param url: The url to open (a scheme requests has no adapter for, e.g. ``file:``)
    :param timeout: Timeout in seconds, passed through to ``urlopen``
    :raises RequestException: If ``urlopen`` fails with an OSError.
    """
    try:
        raw = urlopen(url, timeout=timeout)
    except OSError as e:
        msg = f'Error getting {url}: {e}'
        logger.error(msg)
        # Chain the original error so the underlying cause shows in tracebacks
        raise RequestException(msg) from e
    resp = requests.Response()
    resp.raw = raw
    # requests passes the `decode_content` kwarg to read; swallow it. Give `size` a
    # default so calls without an explicit size also work.
    orig_read = raw.read
    resp.raw.read = lambda size=-1, **kwargs: orig_read(size)
    resp.status_code = raw.code or 200
    resp.headers = requests.structures.CaseInsensitiveDict(raw.headers)
    if url.startswith('file://'):

        def close(self):
            # file responses have no connection to release; just close the file object
            self.raw.close()

        resp.close = types.MethodType(close, resp)
    return resp
def limit_domains(url: str, limit_dict: dict[str, DomainLimiter]) -> None:
    """If this url matches a domain in `limit_dict`, run the limiter.

    This is separated in to its own function so that limits can be disabled during unit tests with VCR.
    """
    # Only the first matching domain's limiter is run
    matches = (lim for dom, lim in limit_dict.items() if dom in url)
    limiter = next(matches, None)
    if limiter is not None:
        limiter()
def parse_header(header: str) -> tuple[str, Mapping]:
    """Parse a MIME header (such as Content-Type) into a main value and a dictionary of parameters.

    Replaces function in the deprecated cgi stdlib module.
    """
    # Let the email package do the parsing by assigning the raw header value
    message = EmailMessage()
    message['content-type'] = header
    parsed = message['content-type']
    return message.get_content_type(), parsed.params
class Session(requests.Session):
    """Subclass of requests Session class which defines some of our own defaults, records unresponsive sites, and raises errors by default."""

    def __init__(self, timeout: int = 30, max_retries: int = 1, **kwargs) -> None:
        """Set some defaults for our session if not explicitly defined.

        :param timeout: Default timeout (seconds) applied to every request from this session.
        :param max_retries: Retry count applied to the mounted transport adapters.
        """
        # **kwargs is accepted for backward compatibility but currently unused
        super().__init__()
        self.timeout = timeout
        # Apply retries to every mounted adapter. Previously only the 'http://'
        # adapter was configured, so https:// requests were never retried.
        for adapter in self.adapters.values():
            adapter.max_retries = max_retries
        # Stores min intervals between requests for certain sites
        self.domain_limiters: dict[str, DomainLimiter] = {}
        self.headers.update({'User-Agent': f'FlexGet/{version} (www.flexget.com)'})

    def add_cookiejar(self, cookiejar):
        """Merge cookies from `cookiejar` into cookiejar for this session.

        :param cookiejar: CookieJar instance to add to the session.
        """
        for cookie in cookiejar:
            self.cookies.set_cookie(cookie)

    def set_domain_delay(self, domain, delay):
        """Do not use this anymore as it is DEPRECATED. Use `add_domain_limiter`.

        Register a minimum interval between requests to `domain`

        :param domain: The domain to set the interval on
        :param delay: The amount of time between requests, can be a timedelta or string like '3 seconds'
        """
        warnings.warn(
            'set_domain_delay is deprecated, use add_domain_limiter',
            DeprecationWarning,
            stacklevel=2,
        )
        self.domain_limiters[domain] = TimedLimiter(domain, delay)

    def add_domain_limiter(self, limiter: DomainLimiter, replace: bool = True) -> None:
        """Add a limiter to throttle requests to a specific domain.

        :param DomainLimiter limiter: The `DomainLimiter` to add to the session.
        :param replace: If `True`, an existing domain limiter for this domain will be replaced.
            If `False`, no changes will be made.
        """
        if limiter.domain in self.domain_limiters and not replace:
            return
        self.domain_limiters[limiter.domain] = limiter

    def request(self, method: str, url: str, *args, **kwargs) -> requests.Response:
        """Do a request, but raise Timeout immediately if site is known to timeout, and record sites that timeout.

        Also raises errors getting the content by default.

        :param bool raise_status: If True, non-success status code responses will be raised as errors (True by default)
        :param disable_limiters: If True, any limiters configured for this session will be ignored for this request.
        """
        # Raise Timeout right away if site is known to timeout
        if is_unresponsive(url):
            raise requests.Timeout(
                f'Requests to this site ({urlparse(url).hostname}) have timed out recently. '
                'Waiting before trying again.'
            )
        # Run domain limiters for this url
        if not kwargs.pop('disable_limiters', False):
            limit_domains(url, self.domain_limiters)
        kwargs.setdefault('timeout', self.timeout)
        raise_status = kwargs.pop('raise_status', True)
        # If we do not have an adapter for this url, pass it off to urllib
        if not any(url.startswith(adapter) for adapter in self.adapters):
            logger.debug('No adaptor, passing off to urllib')
            return _wrap_urlopen(url, timeout=kwargs['timeout'])
        try:
            logger.debug(
                '{}ing URL {} with args {} and kwargs {}', method.upper(), url, args, kwargs
            )
            result = super().request(method, url, *args, **kwargs)
        except requests.Timeout:
            # Mark this site in known unresponsive list
            set_unresponsive(url)
            raise
        if raise_status:
            # Raise HTTPError for 4xx/5xx unless the caller opted out
            result.raise_for_status()
        return result
# Define some module level functions that use our Session, so this module can be used like main requests module
def request(method: str, url: str, **kwargs) -> requests.Response:
    """Send a request through a flexget `Session`.

    :param method: HTTP method to use.
    :param url: URL for the new :class:`Request` object.
    :param kwargs: Optional arguments that ``Session.request`` takes. A ``session``
        kwarg may supply an existing session to use; otherwise a fresh one is created.
    """
    # Only build a Session when one wasn't supplied; the previous
    # `kwargs.pop('session', Session())` constructed a throwaway Session
    # eagerly even when a session was passed in.
    s = kwargs.pop('session', None)
    if s is None:
        s = Session()
    return s.request(method=method, url=url, **kwargs)
def head(url: str, **kwargs) -> requests.Response:
    """Send a HEAD request. Return :class:`Response` object.

    :param url: URL for the new :class:`Request` object.
    :param kwargs: Optional arguments that ``request`` takes.
    """
    # Follow redirects unless the caller explicitly opted out
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('head', url, **kwargs)
def get(url: str, **kwargs) -> requests.Response:
    """Send a GET request. Return :class:`Response` object.

    :param url: URL for the new :class:`Request` object.
    :param kwargs: Optional arguments that ``request`` takes.
    """
    # Follow redirects unless the caller explicitly opted out
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('get', url, **kwargs)
def post(url: str, data=None, **kwargs) -> requests.Response:
    """Send a POST request. Return :class:`Response` object.

    :param url: URL for the new :class:`Request` object.
    :param data: (optional) Dictionary or bytes to send in the body of the :class:`Request`.
    :param kwargs: Optional arguments that ``request`` takes.
    """
    # `data` can never already be in kwargs (it is a named parameter above)
    kwargs.setdefault('data', data)
    return request('post', url, **kwargs)