Skip to content

Commit 61dcf49

Browse files
authored
Allow for enterprise base url prefixed with api. (#3419)
Supports enterprise base urls like `api.SUBDOMAIN.DOMAIN.TLD` to fetch content from `SUBDOMAIN.DOMAIN.TLD`, while not allowing access to `DOMAIN.TLD`. Fixes #3398. Supersedes #3399.
1 parent ae23d60 commit 61dcf49

2 files changed

Lines changed: 44 additions & 10 deletions

File tree

github/Requester.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ def __init__(
409409
self.__graphql_prefix = self.get_graphql_prefix(o.path)
410410
self.__graphql_url = urllib.parse.urlunparse(o._replace(path=self.__graphql_prefix))
411411
self.__hostname = o.hostname # type: ignore
412+
if base_url == Consts.DEFAULT_BASE_URL:
413+
self.__domains = ["github.com", "githubusercontent.com"]
414+
else:
415+
self.__domains = list({o.hostname, o.hostname.removeprefix("api.")}) # type: ignore
412416
self.__port = o.port
413417
self.__prefix = o.path
414418
self.__timeout = timeout
@@ -848,7 +852,7 @@ def __check(
848852
return responseHeaders, data
849853

850854
@classmethod
851-
def __hostnameHasDomain(cls, hostname: str, domain_or_domains: str | tuple[str, ...]) -> bool:
855+
def __hostnameHasDomain(cls, hostname: str, domain_or_domains: str | list[str]) -> bool:
852856
if isinstance(domain_or_domains, str):
853857
if hostname == domain_or_domains:
854858
return True
@@ -864,10 +868,7 @@ def __assertUrlAllowed(self, url: str) -> None:
864868
assert o.path.startswith(tuple(prefixes)), o.path
865869
assert o.port == self.__port, o.port
866870
else:
867-
if self.__base_url == Consts.DEFAULT_BASE_URL:
868-
assert self.__hostnameHasDomain(o.hostname, ("github.com", "githubusercontent.com")), o.hostname
869-
else:
870-
assert self.__hostnameHasDomain(o.hostname, self.__hostname), o.hostname
871+
assert self.__hostnameHasDomain(o.hostname, self.__domains), o.hostname
871872

872873
def __customConnection(self, url: str) -> HTTPRequestsConnectionClass | HTTPSRequestsConnectionClass | None:
873874
cnx: HTTPRequestsConnectionClass | HTTPSRequestsConnectionClass | None = None

tests/Requester.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,18 +271,19 @@ def testHostnameHasDomain(self):
271271
assert self.g.requester.__hostnameHasDomain("ghe.local", "ghe.local")
272272
assert self.g.requester.__hostnameHasDomain("api.ghe.local", "ghe.local")
273273
assert self.g.requester.__hostnameHasDomain("api.prod.ghe.local", "prod.ghe.local")
274-
assert self.g.requester.__hostnameHasDomain("github.com", ("github.com", "githubusercontent.com"))
275-
assert self.g.requester.__hostnameHasDomain("api.github.com", ("github.com", "githubusercontent.com"))
276-
assert self.g.requester.__hostnameHasDomain("githubusercontent.com", ("github.com", "githubusercontent.com"))
274+
assert self.g.requester.__hostnameHasDomain("github.com", ["github.com", "githubusercontent.com"])
275+
assert self.g.requester.__hostnameHasDomain("api.github.com", ["github.com", "githubusercontent.com"])
276+
assert self.g.requester.__hostnameHasDomain("githubusercontent.com", ["github.com", "githubusercontent.com"])
277277
assert self.g.requester.__hostnameHasDomain(
278-
"objects.githubusercontent.com", ("github.com", "githubusercontent.com")
278+
"objects.githubusercontent.com", ["github.com", "githubusercontent.com"]
279279
)
280280
assert self.g.requester.__hostnameHasDomain("maliciousgithub.com", "github.com") is False
281-
assert self.g.requester.__hostnameHasDomain("abc.def", ("github.com", "githubusercontent.com")) is False
281+
assert self.g.requester.__hostnameHasDomain("abc.def", ["github.com", "githubusercontent.com"]) is False
282282

283283
def testAssertUrlAllowed(self):
284284
# default github.com requester
285285
requester = self.g.requester
286+
self.assertEqual(set(requester.__domains), {"github.com", "githubusercontent.com"})
286287

287288
for allowed in [
288289
"https://api.github.com/request",
@@ -308,6 +309,7 @@ def testAssertUrlAllowed(self):
308309

309310
# custom (Enterprise) requester with prefix
310311
requester = github.Github(base_url="https://prod.ghe.local/github-api/").requester
312+
self.assertEqual(set(requester.__domains), {"prod.ghe.local"})
311313

312314
for allowed in [
313315
"https://prod.ghe.local/github-api/request",
@@ -335,6 +337,37 @@ def testAssertUrlAllowed(self):
335337
requester.__assertUrlAllowed(not_allowed)
336338
self.assertEqual(exc.exception.args, (arg,))
337339

340+
# custom (Enterprise) requester with api subdomain and prefix
341+
requester = github.Github(base_url="https://api.prod.ghe.local/github-api/").requester
342+
self.assertEqual(set(requester.__domains), {"api.prod.ghe.local", "prod.ghe.local"})
343+
344+
for allowed in [
345+
"https://api.prod.ghe.local/github-api/request",
346+
"https://prod.ghe.local/path",
347+
"https://uploads.prod.ghe.local/path",
348+
"https://status.prod.ghe.local/path",
349+
]:
350+
requester.__assertUrlAllowed(allowed)
351+
352+
for not_allowed, arg in [
353+
("https://api.prod.ghe.local/path", "/path"),
354+
("https://ghe.local/path", "ghe.local"),
355+
("https://api.github.com/request", "api.github.com"),
356+
("https://github.com/path", "github.com"),
357+
("https://uploads.github.com/path", "uploads.github.com"),
358+
("https://status.github.com/path", "status.github.com"),
359+
("https://githubusercontent.com/path", "githubusercontent.com"),
360+
("https://objects.githubusercontent.com/path", "objects.githubusercontent.com"),
361+
(
362+
"https://release-assets.githubusercontent.com/path",
363+
"release-assets.githubusercontent.com",
364+
),
365+
("https://example.com/", "example.com"),
366+
]:
367+
with self.assertRaises(AssertionError) as exc:
368+
requester.__assertUrlAllowed(not_allowed)
369+
self.assertEqual(exc.exception.args, (arg,))
370+
338371
def testMakeAbsoluteUrl(self):
339372
# default github.com requester
340373
requester = self.g.requester

0 commit comments

Comments
 (0)