Coverage for src/local_deep_research/journal_quality/data_sources/_openalex_common.py: 97%
47 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1"""Shared helpers for the two OpenAlex snapshot fetchers.
Sources and institutions both pull from the OpenAlex S3 bucket, translate
``s3://`` URLs to the public HTTPS gateway, and defend in depth against
a compromised manifest by allowlisting the ``s3://openalex/`` prefix.
This file owns the shared symbols behind those concerns (the gateway
base URL, the allowlisted prefix, and the URL-translation and
manifest-validation helpers) so ``openalex.py`` and ``institutions.py``
don't duplicate them (and can't drift). It also owns the per-partition
streaming helper they both use to iterate records with consistent
malformed-line suppression and tmp-file lifecycle.
11"""
13from __future__ import annotations
15import gzip
16import json
17from pathlib import Path
18from typing import Callable, Iterator, Tuple
20from loguru import logger
# The OpenAlex snapshot is public: CC0-licensed, no authentication, no
# rate limits. Its manifest layout is described at
# https://docs.openalex.org/download-all-data/snapshot-data-format
#
# Every item in ``manifest["entries"]`` carries an ``s3://...`` ``url``
# together with ``meta.content_length`` and ``meta.record_count``. Rather
# than depend on boto3, s3:// URLs are rewritten onto this public HTTPS
# gateway.
OPENALEX_S3_BASE: str = "https://openalex.s3.amazonaws.com"

# The only URL prefix a manifest part may use. This is a second layer on
# top of safe_get's private-IP block: without it, a compromised or
# malformed manifest could steer the fetchers at arbitrary
# attacker-controlled URLs.
OPENALEX_MANIFEST_ALLOWED_PREFIX: str = "s3://openalex/"
def s3_to_https(s3_url: str) -> str:
    """Translate ``s3://openalex/...`` to the public HTTPS gateway.

    Only a *leading* ``OPENALEX_MANIFEST_ALLOWED_PREFIX`` is rewritten; the
    remainder of the URL (the object key) is appended to
    ``OPENALEX_S3_BASE`` unchanged.

    Args:
        s3_url: Manifest part URL. Must start with
            ``OPENALEX_MANIFEST_ALLOWED_PREFIX``.

    Returns:
        ``OPENALEX_S3_BASE + "/" + <key>``, where ``<key>`` is everything
        after the allowlisted prefix.

    Raises:
        ValueError: If ``s3_url`` does not start with the allowlisted
            prefix. This is defense in depth for callers that did not run
            ``validate_manifest_entries`` first.
    """
    prefix = OPENALEX_MANIFEST_ALLOWED_PREFIX
    # Anchor to the prefix: ``str.replace(prefix, ..., 1)`` would match the
    # first occurrence anywhere, so a URL that merely *embeds* the prefix
    # (e.g. ``https://evil.example/s3://openalex/x``) would be rewritten
    # into a still-foreign URL instead of being rejected.
    if not s3_url.startswith(prefix):
        raise ValueError(
            "Refusing to translate URL outside the OpenAlex bucket "
            f"(must start with {prefix!r}): {s3_url!r}"
        )
    return f"{OPENALEX_S3_BASE}/{s3_url[len(prefix):]}"
def validate_manifest_entries(entries: list[dict], label: str) -> None:
    """Refuse to fetch if any manifest entry escapes the S3 allowlist.

    Defense-in-depth: a compromised or tampered manifest could list
    URLs outside the OpenAlex bucket. Refusing the whole fetch rather
    than fetching some-and-not-others keeps failure modes simple.

    A tampered manifest may also be structurally wrong. Entries that are
    not dicts, and ``url`` values that are missing or not strings, are
    rejected with the same ``ValueError``. Otherwise they would escape as
    an ``AttributeError`` from ``.get`` / ``.startswith``.

    Args:
        entries: ``manifest["entries"]``; each entry is expected to be a
            dict with a ``url`` key.
        label: Human-readable source name used in the error message.

    Raises:
        ValueError: On the first entry that is not a dict, or whose
            ``url`` is not a string starting with
            ``OPENALEX_MANIFEST_ALLOWED_PREFIX``.
    """
    for entry in entries:
        if not isinstance(entry, dict):
            raise ValueError(
                f"{label} manifest contains a non-object entry: {entry!r}"
            )
        raw = entry.get("url", "")
        # A missing url yields "" and fails the prefix check below.
        if not isinstance(raw, str) or not raw.startswith(
            OPENALEX_MANIFEST_ALLOWED_PREFIX
        ):
            raise ValueError(
                f"{label} manifest contains disallowed URL "
                f"(must start with {OPENALEX_MANIFEST_ALLOWED_PREFIX!r}): "
                f"{raw!r}"
            )
# Per-partition retry budget.
#
# The generic ``safe_get_with_retries`` budget (3 retries, 1-2-4 s
# backoff = ~7 s total) is sized for small request bodies. It trips on a
# sustained mid-stream S3 hiccup: every retry of a ~5-10 MB partition that
# lands inside the same bad window fails the same way, exhausts the budget
# in seconds, and aborts the whole ~30-partition pull. The release-gate
# workflow saw this twice in a row on 2026-04-26.
#
# 5 retries with a 2-5-10-20-40 s backoff (77 s of sleep in total) ride
# out a ~75 s S3 blip instead.
#
# Worst case for ONE partition is 6 attempts plus the sleeps, i.e.
# roughly ``timeout * 6 + 77 s`` (~13 min at the default 120 s timeout).
# This assumes each attempt is bounded by ``timeout`` -- TODO confirm for
# safe_get's consume_body path. That fits inside the 45 min job timeout
# for an isolated bad partition.
#
# The budget is NOT sized for every partition exhausting its retries:
# across ~30 partitions the backoff sleeps alone add up to ~38.5 min, and
# with timeouts the total runs to hours. The job timeout remains the
# backstop for a prolonged outage.
_PARTITION_MAX_RETRIES: int = 5
_PARTITION_BACKOFF_SECONDS: tuple[int, ...] = (2, 5, 10, 20, 40)
77def iter_partitions(
78 entries: list[dict],
79 data_dir: Path,
80 *,
81 file_prefix: str,
82 label: str,
83 safe_get: Callable,
84 timeout: int = 120,
85 max_retries: int = _PARTITION_MAX_RETRIES,
86 backoff_times: tuple = _PARTITION_BACKOFF_SECONDS,
87) -> Iterator[Tuple[int, int, list[dict]]]:
88 """Download each partition, yielding ``(idx, total_parts, records)``.
90 Shared between ``openalex.py`` and ``institutions.py`` so the
91 tmp-file lifecycle and malformed-JSON suppression (first-10
92 warnings + one "further suppressed" notice) are defined once.
94 The caller iterates ``records`` for per-record work and is
95 responsible for per-partition progress logging and ``progress_cb``
96 invocations — those need caller-specific state (running record
97 count, schema-drift counters) that doesn't belong in the helper.
99 Args:
100 entries: ``manifest["entries"]`` — each dict has ``url``
101 starting with ``s3://openalex/``.
102 data_dir: Directory used for the transient ``.<prefix>_part_<n>.gz``
103 files. Cleaned up even on exception.
104 file_prefix: Leaf prefix for tmp files
105 (e.g. ``openalex_sources`` / ``openalex_institutions``).
106 label: Human-readable label used in log messages
107 (e.g. ``"OpenAlex sources"`` / ``"Institutions"``).
108 safe_get: Dependency-injected HTTP getter (lets the caller
109 pick ``safe_get_with_retries`` without forcing a global
110 import at module load). Must accept ``consume_body=True``
111 so body-stream transients (``ChunkedEncodingError``,
112 ``ReadTimeout``) raised during ``resp.content`` are
113 retried inside the wrapper, not propagated to abort the
114 whole multi-partition pull.
115 timeout: Per-partition HTTP timeout (seconds).
116 max_retries: Per-partition retry budget. Defaults higher than
117 ``safe_get_with_retries``' generic 3 because partition
118 bodies are MB-sized and a mid-stream IncompleteRead aborts
119 the whole multi-partition pull on exhaustion.
120 backoff_times: Per-attempt sleep schedule. Defaults to a
121 longer schedule than the generic ``safe_get_with_retries``
122 (1, 2, 4) so we ride out a sustained S3 blip instead of
123 burning all retries inside the same bad window.
124 """
125 malformed_total = 0
126 total_parts = len(entries)
128 for idx, entry in enumerate(entries):
129 part_url = s3_to_https(entry["url"])
130 tmp_part = data_dir / f".{file_prefix}_part_{idx}.gz"
131 records: list[dict] = []
133 try:
134 # consume_body=True: an OpenAlex S3 partition is ~10 MB
135 # gzipped. A mid-stream ChunkedEncodingError /
136 # IncompleteRead would otherwise abort the whole 30+
137 # partition pull. With consume_body, safe_get_with_retries
138 # reads resp.content inside its retry loop and retries
139 # body-stream transients the same way it retries
140 # header-stage failures.
141 resp = safe_get(
142 part_url,
143 timeout=timeout,
144 consume_body=True,
145 max_retries=max_retries,
146 backoff_times=backoff_times,
147 )
148 resp.raise_for_status()
149 tmp_part.write_bytes(resp.content)
151 with gzip.open(tmp_part, "rt", encoding="utf-8") as fh:
152 for line in fh:
153 line = line.strip()
154 if not line: 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true
155 continue
156 try:
157 rec = json.loads(line)
158 except (json.JSONDecodeError, ValueError):
159 malformed_total += 1
160 if malformed_total <= 10:
161 logger.warning(
162 f"{label} partition {idx}: skipping "
163 f"malformed JSON line"
164 )
165 elif malformed_total == 11:
166 logger.warning(
167 f"{label} partition {idx}: further "
168 "malformed lines suppressed"
169 )
170 continue
171 records.append(rec)
172 finally:
173 tmp_part.unlink(missing_ok=True)
175 yield idx, total_parts, records
177 if malformed_total:
178 logger.warning(
179 f"{label}: {malformed_total:,} malformed lines skipped across "
180 f"{total_parts} partitions"
181 )