Coverage for src/local_deep_research/journal_quality/data_sources/_openalex_common.py: 97%

47 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

"""Shared helpers for the two OpenAlex snapshot fetchers.

Sources and institutions both pull from the OpenAlex S3 bucket, translate
``s3://`` URLs to the public HTTPS gateway, and defend-in-depth against
a compromised manifest by allowlisting the ``s3://openalex/`` prefix.
This module owns the pieces that implement that: the gateway base URL
(``OPENALEX_S3_BASE``), the allowlisted prefix
(``OPENALEX_MANIFEST_ALLOWED_PREFIX``), and the helpers built on them
(``s3_to_https`` and ``validate_manifest_entries``). Keeping them here means
``openalex.py`` and ``institutions.py`` don't duplicate them (and can't
drift). It also owns the per-partition streaming helper
(``iter_partitions``) they both use to iterate records with consistent
malformed-line suppression and tmp-file lifecycle.
"""

12 

13from __future__ import annotations 

14 

15import gzip 

16import json 

17from pathlib import Path 

18from typing import Callable, Iterator, Tuple 

19 

20from loguru import logger 

21 

# The OpenAlex snapshot is public (CC0): no credentials, no rate limits.
# Its manifest format is described at
# https://docs.openalex.org/download-all-data/snapshot-data-format
# Every item of ``manifest["entries"]`` carries an ``s3://...`` ``url``
# alongside ``meta.content_length`` and ``meta.record_count``. Rewriting
# the s3:// scheme onto this public HTTPS gateway avoids a boto3 dependency.
OPENALEX_S3_BASE: str = "https://openalex.s3.amazonaws.com"

# Allowlist for manifest part URLs: only parts inside the OpenAlex public
# S3 bucket may be fetched. This is a second line of defense behind
# safe_get's private-IP block, because a tampered or malformed manifest
# could otherwise point the fetcher at arbitrary attacker-controlled URLs.
OPENALEX_MANIFEST_ALLOWED_PREFIX: str = "s3://openalex/"

35 

36 

def s3_to_https(s3_url: str) -> str:
    """Translate ``s3://openalex/...`` to the public HTTPS gateway.

    Only a *leading* ``s3://openalex/`` is rewritten. Any other string is
    returned unchanged, so a URL that merely contains the prefix somewhere
    in its body (e.g. inside a query parameter) is never turned into a
    gateway-looking URL.

    Args:
        s3_url: URL from a manifest entry; expected to start with
            ``OPENALEX_MANIFEST_ALLOWED_PREFIX``.

    Returns:
        The equivalent ``OPENALEX_S3_BASE + "/..."`` HTTPS URL, or
        ``s3_url`` unchanged when it does not start with the prefix.
    """
    if not s3_url.startswith(OPENALEX_MANIFEST_ALLOWED_PREFIX):
        return s3_url
    key = s3_url[len(OPENALEX_MANIFEST_ALLOWED_PREFIX):]
    return f"{OPENALEX_S3_BASE}/{key}"

42 

43 

def validate_manifest_entries(entries: list[dict], label: str) -> None:
    """Refuse to fetch if any manifest entry escapes the S3 allowlist.

    Defense-in-depth: a compromised or tampered manifest could list
    URLs outside the OpenAlex bucket. Refusing the whole fetch rather
    than fetching some-and-not-others keeps failure modes simple.

    Args:
        entries: ``manifest["entries"]``; each item should be a dict whose
            ``url`` starts with ``OPENALEX_MANIFEST_ALLOWED_PREFIX``.
        label: Human-readable source name used in the error message.

    Raises:
        ValueError: if any entry is not a dict, or its ``url`` is missing,
            not a string, or outside the allowlisted prefix.
    """
    for entry in entries:
        # A non-dict entry is reported as-is; a missing key shows as ''.
        raw = entry.get("url", "") if isinstance(entry, dict) else entry
        # ``"url": null`` (or any non-string) must be rejected with the
        # same ValueError rather than an AttributeError from .startswith.
        if not isinstance(raw, str) or not raw.startswith(
            OPENALEX_MANIFEST_ALLOWED_PREFIX
        ):
            raise ValueError(
                f"{label} manifest contains disallowed URL "
                f"(must start with {OPENALEX_MANIFEST_ALLOWED_PREFIX!r}): "
                f"{raw!r}"
            )

59 

60 

# Per-partition retry budget. The default ``safe_get_with_retries``
# budget (3 retries, 1-2-4 s backoff = ~7 s total) is sized for small
# request bodies and trips on a sustained mid-stream S3 hiccup: every
# retry of a ~5–10 MB partition that lands inside the same bad window
# fails the same way, exhausts the budget in seconds, and aborts the
# whole 30-partition pull. The release-gate workflow saw this twice in
# a row on 2026-04-26.
#
# 5 retries with 2-5-10-20-40 s backoff (77 s of sleeping) rides out a
# ~75 s S3 blip instead. Worst case for a single partition is roughly
# ``timeout * 6 + 77 s`` (~13 min at iter_partitions' default 120 s
# timeout), which fits inside the 45 min job timeout for an isolated bad
# partition. It does NOT fit if every partition exhausts its retries:
# across ~30 partitions the backoff sleeps alone add up to ~38 min before
# any request time is counted, so a sustained outage will still hit the
# job timeout.
_PARTITION_BACKOFF_SECONDS = (2, 5, 10, 20, 40)
# One backoff step per retry; derived from the schedule so the two
# settings cannot drift apart.
_PARTITION_MAX_RETRIES = len(_PARTITION_BACKOFF_SECONDS)

75 

76 

# Number of individual malformed-line warnings emitted before a single
# "further malformed lines suppressed" notice.
_MALFORMED_LOG_LIMIT = 10


def _warn_malformed(label: str, idx: int, malformed_total: int) -> None:
    """Warn for the first few malformed lines, then once that more are hidden."""
    if malformed_total <= _MALFORMED_LOG_LIMIT:
        logger.warning(
            f"{label} partition {idx}: skipping "
            f"malformed JSON line"
        )
    elif malformed_total == _MALFORMED_LOG_LIMIT + 1:
        logger.warning(
            f"{label} partition {idx}: further "
            "malformed lines suppressed"
        )


def iter_partitions(
    entries: list[dict],
    data_dir: Path,
    *,
    file_prefix: str,
    label: str,
    safe_get: Callable,
    timeout: int = 120,
    max_retries: int = _PARTITION_MAX_RETRIES,
    backoff_times: tuple = _PARTITION_BACKOFF_SECONDS,
) -> Iterator[Tuple[int, int, list[dict]]]:
    """Download each partition, yielding ``(idx, total_parts, records)``.

    Shared between ``openalex.py`` and ``institutions.py`` so the
    tmp-file lifecycle and malformed-line suppression (first-10
    warnings + one "further suppressed" notice) are defined once.

    Every line is parsed on its own. A line that is not valid UTF-8, is
    not valid JSON, or decodes to something other than a JSON object is
    counted as malformed and skipped, so a single bad line can never abort
    the whole multi-partition pull. A summary warning with the total count
    is logged once the generator is exhausted.

    The caller iterates ``records`` for per-record work and is
    responsible for per-partition progress logging and ``progress_cb``
    invocations — those need caller-specific state (running record
    count, schema-drift counters) that doesn't belong in the helper.

    Args:
        entries: ``manifest["entries"]`` — each dict has ``url``
            starting with ``s3://openalex/``.
        data_dir: Directory used for the transient ``.<prefix>_part_<n>.gz``
            files. Each file is removed before its partition is yielded,
            including when an exception is raised.
        file_prefix: Leaf prefix for tmp files
            (e.g. ``openalex_sources`` / ``openalex_institutions``).
        label: Human-readable label used in log messages
            (e.g. ``"OpenAlex sources"`` / ``"Institutions"``).
        safe_get: Dependency-injected HTTP getter (lets the caller
            pick ``safe_get_with_retries`` without forcing a global
            import at module load). Must accept ``consume_body=True``
            so body-stream transients (``ChunkedEncodingError``,
            ``ReadTimeout``) raised during ``resp.content`` are
            retried inside the wrapper, not propagated to abort the
            whole multi-partition pull.
        timeout: Per-partition HTTP timeout (seconds).
        max_retries: Per-partition retry budget. Defaults higher than
            ``safe_get_with_retries``' generic 3 because partition
            bodies are MB-sized and a mid-stream IncompleteRead aborts
            the whole multi-partition pull on exhaustion.
        backoff_times: Per-attempt sleep schedule. Defaults to a
            longer schedule than the generic ``safe_get_with_retries``
            (1, 2, 4) so we ride out a sustained S3 blip instead of
            burning all retries inside the same bad window.

    Yields:
        ``(idx, total_parts, records)`` per manifest entry, where
        ``records`` holds the JSON objects parsed from that partition.

    Raises:
        Whatever ``safe_get`` or ``resp.raise_for_status()`` raises, and
        file / gzip errors while writing or reading the tmp file (e.g. a
        truncated archive). The tmp file is removed in every case.
    """
    malformed_total = 0
    total_parts = len(entries)

    for idx, entry in enumerate(entries):
        part_url = s3_to_https(entry["url"])
        tmp_part = data_dir / f".{file_prefix}_part_{idx}.gz"
        records: list[dict] = []

        try:
            # consume_body=True: an OpenAlex S3 partition is ~10 MB
            # gzipped. A mid-stream ChunkedEncodingError /
            # IncompleteRead would otherwise abort the whole 30+
            # partition pull. With consume_body, safe_get_with_retries
            # reads resp.content inside its retry loop and retries
            # body-stream transients the same way it retries
            # header-stage failures.
            resp = safe_get(
                part_url,
                timeout=timeout,
                consume_body=True,
                max_retries=max_retries,
                backoff_times=backoff_times,
            )
            resp.raise_for_status()
            tmp_part.write_bytes(resp.content)

            # Read raw bytes and let json.loads decode each line. In text
            # mode an invalid UTF-8 byte raises UnicodeDecodeError from the
            # file iterator itself (outside any per-line try) and aborts
            # the pull; json.loads(bytes) raises it per line instead, and
            # UnicodeDecodeError is a ValueError, so it is caught below.
            with gzip.open(tmp_part, "rb") as fh:
                for raw_line in fh:
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        # ValueError covers json.JSONDecodeError and
                        # UnicodeDecodeError.
                        rec = json.loads(line)
                    except ValueError:
                        rec = None
                    # Callers receive list[dict]; valid JSON that is not
                    # an object (null, numbers, arrays) is malformed too.
                    if not isinstance(rec, dict):
                        malformed_total += 1
                        _warn_malformed(label, idx, malformed_total)
                        continue
                    records.append(rec)
        finally:
            tmp_part.unlink(missing_ok=True)

        yield idx, total_parts, records

    if malformed_total:
        logger.warning(
            f"{label}: {malformed_total:,} malformed lines skipped across "
            f"{total_parts} partitions"
        )