Coverage for src/local_deep_research/web/services/socket_service.py: 96%

209 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1from threading import Lock 

2from typing import Any 

3 

4from flask import Flask, request, session 

5from flask_socketio import SocketIO 

6from loguru import logger 

7 

8from ...constants import ResearchStatus 

9from ...database.encrypted_db import db_manager 

10from ...database.session_passwords import session_password_store 

11from ..routes.globals import get_active_research_snapshot 

12 

13 

class SocketIOService:
    """
    Singleton class for managing SocketIO connections and subscriptions.

    Wraps a Flask-SocketIO server bound to a Flask app, tracks which socket
    ids (sids) are subscribed to which research ids, and only lets a user
    subscribe to research that exists in their own encrypted database.
    """

    _instance = None

    def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any):
        """
        Return the process-wide instance, creating it on first use.

        Args:
            app: The Flask app to bind this service to. It must be specified
                the first time this is called and the singleton instance is
                created, but will be ignored after that.
            *args: Arguments to pass to the superclass's __new__ method.
            **kwargs: Keyword arguments to pass to the superclass's __new__ method.

        Raises:
            ValueError: If no instance exists yet and ``app`` is None.
        """
        if not cls._instance:
            if app is None:
                raise ValueError(
                    "Flask app must be specified to create a SocketIOService instance."
                )
            cls._instance = super(SocketIOService, cls).__new__(
                cls, *args, **kwargs
            )
            cls._instance.__init_singleton(app)
        return cls._instance

    def __init_singleton(self, app: Flask) -> None:
        """
        Initializes the singleton instance.

        Resolves the WebSocket CORS policy, creates the SocketIO server,
        sets up subscription tracking and registers the event handlers.

        Args:
            app: The app to bind this service to.
        """
        from threading import local

        from ...settings.env_registry import get_env_setting

        self.__app = app  # Store the Flask app reference

        # Per-thread logging switch (see __should_log). It is thread-local
        # rather than a plain boolean: the server runs with
        # async_mode="threading", and suppressing logs while emitting from
        # inside a logging handler on one thread must not silence - or
        # prematurely re-enable - logging on any other thread.
        self.__log_state = local()
        # Protects access to shared state.
        self.__lock = Lock()
        # Socket subscription tracking: research_id -> set of subscribed sids.
        self.__socket_subscriptions: dict[str, Any] = {}

        # Determine WebSocket CORS policy from env var or default
        socketio_cors = self.__resolve_cors_origins(
            get_env_setting("security.websocket.allowed_origins")
        )
        if socketio_cors is None:
            logger.info(
                "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)"
            )
        elif socketio_cors == "*":
            logger.debug("Socket.IO CORS: all origins allowed")
        else:
            logger.info(f"Socket.IO CORS: restricted to {socketio_cors}")

        self.__socketio = SocketIO(
            app,
            cors_allowed_origins=socketio_cors,
            async_mode="threading",
            path="/socket.io",
            logger=False,
            engineio_logger=False,
            ping_timeout=20,
            ping_interval=5,
        )

        self.__register_event_handlers()

    @staticmethod
    def __resolve_cors_origins(
        ws_origins_env: str | None,
    ) -> str | list[str] | None:
        """
        Translate the configured allowed-origins value into a SocketIO policy.

        Args:
            ws_origins_env: Raw setting value, or None when it is unset.

        Returns:
            "*" when unset (preserves the existing permissive default) or set
            to "*"; a list of origins for a comma-separated value; None
            (same-origin only) when the value contains no origins at all.
        """
        if ws_origins_env is None:
            # No env var set — preserve existing permissive default
            return "*"
        if ws_origins_env == "*":
            return "*"
        # Drop blank entries produced by stray/trailing commas.
        origins = [o.strip() for o in ws_origins_env.split(",") if o.strip()]
        # Return None (same-origin only) rather than an empty list: engineio
        # treats [] as "disable CORS handling", which would widen access.
        return origins or None

    def __register_event_handlers(self) -> None:
        """Register the Socket.IO event and error handlers on the server."""

        @self.__socketio.on("connect")
        def on_connect():
            return self.__handle_connect(request)

        @self.__socketio.on("disconnect")
        def on_disconnect(reason: str):
            self.__handle_disconnect(request, reason)

        @self.__socketio.on("subscribe_to_research")
        def on_subscribe(data):
            self.__handle_subscribe(data, request)

        # Backwards-compatible alias: the JS client emits 'join' on subscribe.
        # Without this, the catch-up snapshot in __handle_subscribe never
        # fires and per-client targeting falls through to broadcast.
        @self.__socketio.on("join")
        def on_join(data):
            self.__handle_subscribe(data, request)

        @self.__socketio.on("leave")
        def on_leave(data):
            self.__handle_unsubscribe(data, request)

        @self.__socketio.on("unsubscribe_from_research")
        def on_unsubscribe(data):
            self.__handle_unsubscribe(data, request)

        # on_error is a decorator *factory* (on_error(namespace=None)); it must
        # be called. Used bare, the handler would be taken as the namespace
        # argument and never registered.
        @self.__socketio.on_error()
        def on_error(e):
            return self.__handle_socket_error(e)

        # on_error_default, by contrast, is applied directly.
        @self.__socketio.on_error_default
        def on_default_error(e):
            return self.__handle_default_error(e)

    def __should_log(self) -> bool:
        """Return False while logging is suppressed on the current thread."""
        return getattr(self.__log_state, "enabled", True)

    def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an info message unless suppressed on this thread."""
        if self.__should_log():
            logger.info(message, *args, **kwargs)

    def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an error message unless suppressed on this thread."""
        if self.__should_log():
            logger.error(message, *args, **kwargs)

    def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an exception (with traceback) unless suppressed on this thread."""
        if self.__should_log():
            logger.exception(message, *args, **kwargs)

    def emit_socket_event(self, event, data, room=None):
        """
        Emit a socket event to clients.

        Args:
            event: The event name to emit
            data: The data to send with the event
            room: Optional room ID to send to specific client

        Returns:
            bool: True if emission was successful, False otherwise
        """
        try:
            if room:
                # If room is specified, only emit to that room
                self.__socketio.emit(event, data, room=room)
            else:
                # Otherwise broadcast to all
                self.__socketio.emit(event, data)
            return True
        except Exception:
            self.__log_exception(f"Error emitting socket event {event}")
            return False

    def emit_to_subscribers(
        self, event_base, research_id, data, enable_logging: bool = True
    ):
        """
        Emit an event to all subscribers of a specific research.

        Args:
            event_base: Base event name (will be formatted with research_id)
            research_id: ID of the research
            data: The data to send with the event
            enable_logging: If set to false, this will disable all logging
                on the current thread for the duration of the call, which is
                useful if we are calling this inside of a logging handler.

        Returns:
            bool: True if emission was successful, False otherwise
        """
        # Remember this thread's current setting so that nested calls (and
        # calls made while an outer caller has suppressed logging) restore it
        # instead of unconditionally re-enabling logging.
        previous = self.__should_log()
        if not enable_logging:
            self.__log_state.enabled = False

        try:
            full_event = f"{event_base}_{research_id}"

            # Snapshot the subscriber set under the lock so iteration cannot
            # race with concurrent subscribe/unsubscribe/disconnect handlers.
            with self.__lock:
                subscribers = tuple(
                    self.__socket_subscriptions.get(research_id, ())
                )

            # Emit only to specific subscribers (no broadcast) to avoid
            # duplicate messages and reduce server load under concurrency.
            # Emitting happens outside the (non-reentrant) lock because
            # SocketIO may run handlers, which take the lock, during emit.
            for sid in subscribers:
                try:
                    self.__socketio.emit(full_event, data, room=sid)
                except Exception:
                    self.__log_exception(f"Error emitting to subscriber {sid}")
            # When no targeted subscribers exist yet, drop the event.
            # The catch-up snapshot in __handle_subscribe replays the
            # latest progress on subscribe, so early-arriving events
            # are recovered correctly without a cross-user broadcast.

            return True
        except Exception:
            self.__log_exception(
                f"Error emitting to subscribers for research {research_id}"
            )
            return False
        finally:
            self.__log_state.enabled = previous

    def remove_subscriptions_for_research(self, research_id: str) -> None:
        """Remove all socket subscriptions for a completed research."""
        with self.__lock:
            removed = self.__socket_subscriptions.pop(research_id, None)
        if removed is not None:
            self.__log_info(
                f"Removed {len(removed)} subscription(s) for research {research_id}"
            )

    def __handle_connect(self, request):
        """
        Handle client connection.

        Accepts the connection only for a logged-in session whose per-user
        database is open, lazily opening it with the session's stored
        password when needed.

        Returns:
            bool: True to accept the connection, False to reject it.
        """
        username = session.get("username")
        if not username:
            self.__log_info(
                f"Rejected unauthenticated WebSocket connection from {request.sid}"
            )
            return False
        if not db_manager.is_user_connected(username):
            # Cookie is valid but the per-user DB engine isn't open yet (race vs first
            # XHR after page load, gunicorn worker restart, or idle eviction). Lazily
            # open it using the password the user authenticated with at login.
            session_id = session.get("session_id")
            password = (
                session_password_store.get_session_password(
                    username, session_id
                )
                if session_id
                else None
            )
            if not password:
                self.__log_info(
                    f"Rejected WebSocket connection for {username}: no active DB session and no stored password"
                )
                return False
            try:
                db_manager.open_user_database(username, password)
            except Exception as e:
                # Use __log_error (not __log_exception) so loguru cannot include
                # the `password` local in a diagnose=True traceback.
                self.__log_error(
                    f"Lazy DB open failed for {username} at WebSocket connect: {type(e).__name__}"
                )
                return False
        self.__log_info(f"Client connected: {request.sid} (user: {username})")
        return True

    def __handle_disconnect(self, request, reason: str):
        """Handle client disconnection: drop its subscriptions and thread session."""
        try:
            self.__log_info(
                f"Client {request.sid} disconnected because: {reason}"
            )
            # Clean up subscriptions for this client.
            # __socket_subscriptions is keyed by research_id → set of sids,
            # so we iterate all entries and discard the disconnecting sid.
            with self.__lock:
                empty_keys = []
                for research_id, sids in self.__socket_subscriptions.items():
                    sids.discard(request.sid)
                    if not sids:
                        empty_keys.append(research_id)
                for key in empty_keys:
                    del self.__socket_subscriptions[key]
            self.__log_info(f"Removed subscription for client {request.sid}")

            # Clean up any thread-local database sessions that may have been
            # created during socket handler execution. This prevents file
            # descriptor leaks from unclosed SQLAlchemy sessions.
            try:
                from ...database.thread_local_session import (
                    cleanup_current_thread,
                )

                cleanup_current_thread()
            except ImportError:
                pass  # Module not available, skip cleanup
            except Exception:
                self.__log_exception(
                    "Error cleaning up thread session on disconnect"
                )
        except Exception as e:
            self.__log_exception(f"Error handling disconnect: {e}")

    def __handle_subscribe(self, data, request):
        """Handle client subscription to research updates."""
        # The payload comes straight from the client; ignore anything that is
        # not a mapping instead of raising AttributeError on `.get`.
        research_id = (
            data.get("research_id") if isinstance(data, dict) else None
        )
        if not research_id:
            return

        # Verify the connected user actually owns this research before
        # subscribing. The in-memory `_active_research` snapshot is keyed
        # only by research_id (no user tuple), so without this guard any
        # logged-in user could subscribe to any guessed/leaked research
        # UUID and receive its progress events. The per-user encrypted DB
        # is the ownership boundary: if the research row doesn't exist in
        # the user's DB, they don't own it.
        username = session.get("username")
        if not username or not self._user_owns_research(username, research_id):
            self.__log_info(
                f"Rejected subscribe from {request.sid}: user does not own research {research_id}"
            )
            return

        with self.__lock:
            self.__socket_subscriptions.setdefault(research_id, set()).add(
                request.sid
            )
        self.__log_info(
            f"Client {request.sid} subscribed to research {research_id}"
        )

        # Send current status immediately if available in active research
        snapshot = get_active_research_snapshot(research_id)
        if snapshot is None:
            return
        progress = snapshot["progress"]
        latest_log = snapshot["log"][-1] if snapshot["log"] else None
        if latest_log:
            self.emit_socket_event(
                f"progress_{research_id}",
                {
                    "progress": progress,
                    "message": latest_log.get("message", "Processing..."),
                    "status": ResearchStatus.IN_PROGRESS,
                    "log_entry": latest_log,
                },
                room=request.sid,
            )

    @staticmethod
    def _user_owns_research(username: str, research_id: str) -> bool:
        """Return True if the given user owns this research / benchmark id.

        Used as the authorization boundary for WebSocket subscriptions —
        ownership is checked against the user's encrypted SQLite database,
        which is the per-user data partition. A static helper so unit
        tests can exercise the authz logic without standing up the
        singleton/Flask app.

        Recognizes both normal research (``ResearchHistory``, UUID id) and
        benchmark runs (``BenchmarkRun``, integer id) — the benchmark page
        subscribes with its ``BenchmarkRun.id``, which lives in the same
        per-user DB. Both checks stay scoped to the caller's own database,
        so no cross-user access is introduced.
        """
        try:
            from ...database.session_context import get_user_db_session
            from ...database.models import ResearchHistory

            with get_user_db_session(username) as db:
                if (
                    db.query(ResearchHistory.id)
                    .filter(ResearchHistory.id == research_id)
                    .first()
                    is not None
                ):
                    return True

                # Benchmark pages subscribe with their BenchmarkRun.id.
                # Recognize the user's own benchmark runs so the ownership
                # gate doesn't drop benchmark live progress (regression vs.
                # the removed cross-user broadcast). research_id stays a
                # string (never coerced to int — IDs are strings/UUIDs
                # repo-wide); SQLite applies numeric affinity to match the
                # Integer column. Only attempt this for numeric ids.
                if str(research_id).isdigit():
                    from ...database.models.benchmark import BenchmarkRun

                    return (
                        db.query(BenchmarkRun.id)
                        .filter(BenchmarkRun.id == research_id)
                        .first()
                        is not None
                    )
                return False
        except Exception:
            # Conservative: deny on any DB-open or query failure so a
            # transient infra error never silently widens authz.
            logger.opt(exception=True).warning(
                "Failed to verify research ownership for socket subscribe"
            )
            return False

    def __handle_unsubscribe(self, data, request):
        """Handle client unsubscribe from research updates."""
        research_id = (
            data.get("research_id") if isinstance(data, dict) else None
        )
        if not research_id:
            return

        # Symmetric with __handle_subscribe: require the caller to own the
        # research before mutating the per-research subscription set. The
        # practical impact of an unguarded unsubscribe is small (no data
        # exfiltration; subscribe is already guarded), but it keeps the
        # authz boundary consistent and avoids log spam from spoofed sids.
        username = session.get("username")
        if not username or not self._user_owns_research(username, research_id):
            self.__log_info(
                f"Rejected unsubscribe from {request.sid}: user does not own research {research_id}"
            )
            return

        with self.__lock:
            subs = self.__socket_subscriptions.get(research_id)
            if subs:
                subs.discard(request.sid)
                # Prune empty sets so the dict doesn't grow unbounded with
                # stale research_ids over long server runtimes.
                if not subs:
                    self.__socket_subscriptions.pop(research_id, None)
        self.__log_info(
            f"Client {request.sid} unsubscribed from research {research_id}"
        )

    def __handle_socket_error(self, e):
        """Handle Socket.IO errors raised in default-namespace handlers."""
        self.__log_exception(f"Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def __handle_default_error(self, e):
        """Handle Socket.IO errors from namespaces without their own handler."""
        self.__log_exception(f"Unhandled Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def run(self, host: str, port: int, debug: bool = False) -> None:
        """
        Runs the SocketIO server.

        Args:
            host: The hostname to bind the server to.
            port: The port number to listen on.
            debug: Whether to run in debug mode. Defaults to False.
        """
        # Suppress Server header to prevent version information disclosure
        # This must be done before starting the server because Werkzeug adds
        # the header at the HTTP layer, not WSGI layer
        try:
            from werkzeug.serving import WSGIRequestHandler

            WSGIRequestHandler.version_string = lambda self: ""  # type: ignore[method-assign]
            logger.debug("Suppressed Server header for security")
        except ImportError:
            logger.warning(
                "Could not suppress Server header - werkzeug not found"
            )

        logger.info(f"Starting web server on {host}:{port} (debug: {debug})")
        self.__socketio.run(
            self.__app,  # Use the stored Flask app reference
            debug=debug,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
            use_reloader=False,
        )