Coverage for src/local_deep_research/web/services/socket_service.py: 96%
209 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1from threading import Lock
2from typing import Any
4from flask import Flask, request, session
5from flask_socketio import SocketIO
6from loguru import logger
8from ...constants import ResearchStatus
9from ...database.encrypted_db import db_manager
10from ...database.session_passwords import session_password_store
11from ..routes.globals import get_active_research_snapshot
class SocketIOService:
    """
    Singleton class for managing SocketIO connections and subscriptions.

    The service owns the ``SocketIO`` server bound to the Flask app and an
    in-memory map of ``research_id -> set of socket ids`` that is used to
    deliver progress events only to the clients subscribed to a research.
    Access to that map is serialized with a lock.
    """

    _instance = None

    def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any):
        """
        Return the shared instance, creating and initializing it on first use.

        Args:
            app: The Flask app to bind this service to. It must be specified
                the first time this is called and the singleton instance is
                created, but will be ignored after that.
            *args: Arguments to pass to the superclass's __new__ method.
            **kwargs: Keyword arguments to pass to the superclass's __new__ method.

        Raises:
            ValueError: If no instance exists yet and ``app`` is None.
        """
        if cls._instance is None:
            if app is None:
                raise ValueError(
                    "Flask app must be specified to create a SocketIOService instance."
                )
            instance = super(SocketIOService, cls).__new__(
                cls, *args, **kwargs
            )
            # Publish the singleton only after it initialized successfully;
            # otherwise a failure here would leave a half-built instance
            # cached and returned by every later call.
            instance.__init_singleton(app)
            cls._instance = instance
        return cls._instance

    def __init_singleton(self, app: Flask) -> None:
        """
        Initializes the singleton instance.

        Resolves the WebSocket CORS policy, creates the ``SocketIO`` server,
        sets up the subscription state and registers the event handlers.

        Args:
            app: The app to bind this service to.

        """
        self.__app = app  # Store the Flask app reference

        # Determine WebSocket CORS policy from env var or default
        from ...settings.env_registry import get_env_setting

        ws_origins_env = get_env_setting("security.websocket.allowed_origins")
        socketio_cors: str | list[str] | None
        if ws_origins_env is not None:
            if ws_origins_env == "*":
                socketio_cors = "*"
            elif ws_origins_env:
                # Comma-separated list of explicit origins.
                socketio_cors = [o.strip() for o in ws_origins_env.split(",")]
            else:
                # Set but empty: fall back to same-origin only.
                socketio_cors = None
        else:
            # No env var set — preserve existing permissive default
            socketio_cors = "*"

        if socketio_cors is None:
            logger.info(
                "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)"
            )
        elif socketio_cors == "*":
            logger.debug("Socket.IO CORS: all origins allowed")
        else:
            logger.info(f"Socket.IO CORS: restricted to {socketio_cors}")

        self.__socketio = SocketIO(
            app,
            cors_allowed_origins=socketio_cors,
            async_mode="threading",
            path="/socket.io",
            logger=False,
            engineio_logger=False,
            ping_timeout=20,
            ping_interval=5,
        )

        # Socket subscription tracking (research_id -> set of socket ids).
        self.__socket_subscriptions: dict[str, Any] = {}
        # Set to false to disable logging in the event handlers. This can
        # be necessary because it will sometimes run the handlers directly
        # during a call to `emit` that was made in a logging handler.
        self.__logging_enabled = True
        # Protects access to shared state.
        self.__lock = Lock()

        # Register events.
        @self.__socketio.on("connect")
        def on_connect():
            return self.__handle_connect(request)

        @self.__socketio.on("disconnect")
        def on_disconnect(reason: str):
            self.__handle_disconnect(request, reason)

        @self.__socketio.on("subscribe_to_research")
        def on_subscribe(data):
            self.__handle_subscribe(data, request)

        # Backwards-compatible alias: the JS client emits 'join' on subscribe.
        # Without this, the catch-up snapshot in __handle_subscribe never
        # fires and per-client targeting falls through to broadcast.
        @self.__socketio.on("join")
        def on_join(data):
            self.__handle_subscribe(data, request)

        @self.__socketio.on("leave")
        def on_leave(data):
            self.__handle_unsubscribe(data, request)

        @self.__socketio.on("unsubscribe_from_research")
        def on_unsubscribe(data):
            self.__handle_unsubscribe(data, request)

        @self.__socketio.on_error
        def on_error(e):
            return self.__handle_socket_error(e)

        @self.__socketio.on_error_default
        def on_default_error(e):
            return self.__handle_default_error(e)

    def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an info message unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.info(message, *args, **kwargs)

    def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an error message unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.error(message, *args, **kwargs)

    def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an exception unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.exception(message, *args, **kwargs)

    def emit_socket_event(self, event, data, room=None):
        """
        Emit a socket event to clients.

        Args:
            event: The event name to emit
            data: The data to send with the event
            room: Optional room ID to send to specific client

        Returns:
            bool: True if emission was successful, False otherwise
        """
        try:
            # If room is specified, only emit to that room
            if room:
                self.__socketio.emit(event, data, room=room)
            else:
                # Otherwise broadcast to all
                self.__socketio.emit(event, data)
            return True
        except Exception:
            # Go through the gated helper so a failure that happens while
            # logging is disabled (see emit_to_subscribers) does not log.
            self.__log_exception(f"Error emitting socket event {event}")
            return False

    def emit_to_subscribers(
        self, event_base, research_id, data, enable_logging: bool = True
    ):
        """
        Emit an event to all subscribers of a specific research.

        Args:
            event_base: Base event name (will be formatted with research_id)
            research_id: ID of the research
            data: The data to send with the event
            enable_logging: If set to false, this will disable all logging,
                which is useful if we are calling this inside of a logging
                handler.

        Returns:
            bool: True if emission was successful, False otherwise

        NOTE(review): the logging flag is a plain instance attribute that is
        unconditionally set back to True when this method exits, so nested
        or concurrent calls can re-enable logging while another call still
        expects it to be off — confirm whether a thread-local flag is needed.
        """
        if not enable_logging:
            self.__logging_enabled = False

        try:
            full_event = f"{event_base}_{research_id}"

            # Emit only to specific subscribers (no broadcast) to avoid
            # duplicate messages and reduce server load under concurrency
            with self.__lock:
                subscriptions = self.__socket_subscriptions.get(research_id)
                if subscriptions:
                    subscriptions = (
                        subscriptions.copy()
                    )  # snapshot avoids RuntimeError
                else:
                    subscriptions = None
            if subscriptions is not None:
                for sid in subscriptions:
                    try:
                        self.__socketio.emit(full_event, data, room=sid)
                    except Exception:
                        # One failing subscriber must not stop the others.
                        self.__log_exception(
                            f"Error emitting to subscriber {sid}"
                        )
            # When no targeted subscribers exist yet, drop the event.
            # The catch-up snapshot in __handle_subscribe replays the
            # latest progress on subscribe, so early-arriving events
            # are recovered correctly without a cross-user broadcast.

            return True
        except Exception:
            self.__log_exception(
                f"Error emitting to subscribers for research {research_id}"
            )
            return False
        finally:
            self.__logging_enabled = True

    def remove_subscriptions_for_research(self, research_id: str) -> None:
        """Remove all socket subscriptions for a completed research."""
        with self.__lock:
            removed = self.__socket_subscriptions.pop(research_id, None)
        if removed is not None:
            self.__log_info(
                f"Removed {len(removed)} subscription(s) for research {research_id}"
            )

    def __handle_connect(self, request):
        """
        Handle client connection.

        Returns:
            bool: True to accept the connection, False to reject it. The
            connection is rejected when the session has no username, or when
            the user's database is not open and cannot be opened with the
            password stored for the session.
        """
        username = session.get("username")
        if not username:
            self.__log_info(
                f"Rejected unauthenticated WebSocket connection from {request.sid}"
            )
            return False
        if not db_manager.is_user_connected(username):
            # Cookie is valid but the per-user DB engine isn't open yet (race vs first
            # XHR after page load, gunicorn worker restart, or idle eviction). Lazily
            # open it using the password the user authenticated with at login.
            session_id = session.get("session_id")
            password = (
                session_password_store.get_session_password(
                    username, session_id
                )
                if session_id
                else None
            )
            if not password:
                self.__log_info(
                    f"Rejected WebSocket connection for {username}: no active DB session and no stored password"
                )
                return False
            try:
                db_manager.open_user_database(username, password)
            except Exception as e:
                # Use __log_error (not __log_exception) so loguru cannot include
                # the `password` local in a diagnose=True traceback.
                self.__log_error(
                    f"Lazy DB open failed for {username} at WebSocket connect: {type(e).__name__}"
                )
                return False
        self.__log_info(f"Client connected: {request.sid} (user: {username})")
        return True

    def __handle_disconnect(self, request, reason: str):
        """
        Handle client disconnection.

        Drops the disconnecting socket id from every subscription set and
        then cleans up the current thread's database session. Errors are
        logged and never propagated.
        """
        try:
            self.__log_info(
                f"Client {request.sid} disconnected because: {reason}"
            )
            # Clean up subscriptions for this client.
            # __socket_subscriptions is keyed by research_id → set of sids,
            # so we iterate all entries and discard the disconnecting sid.
            with self.__lock:
                empty_keys = []
                for research_id, sids in self.__socket_subscriptions.items():
                    sids.discard(request.sid)
                    if not sids:
                        empty_keys.append(research_id)
                # Delete after the loop: the dict must not change size
                # while it is being iterated.
                for key in empty_keys:
                    del self.__socket_subscriptions[key]
            self.__log_info(f"Removed subscription for client {request.sid}")

            # Clean up any thread-local database sessions that may have been
            # created during socket handler execution. This prevents file
            # descriptor leaks from unclosed SQLAlchemy sessions.
            try:
                from ...database.thread_local_session import (
                    cleanup_current_thread,
                )

                cleanup_current_thread()
            except ImportError:
                pass  # Module not available, skip cleanup
            except Exception:
                self.__log_exception(
                    "Error cleaning up thread session on disconnect"
                )
        except Exception as e:
            self.__log_exception(f"Error handling disconnect: {e}")

    def __handle_subscribe(self, data, request):
        """
        Handle client subscription to research updates.

        Payloads that are not a dict, or that carry no ``research_id``, are
        ignored. After a successful subscribe, the latest log entry of the
        active research (if any) is sent to the subscribing client only.
        """
        # Client-supplied payload: tolerate anything that is not a dict
        # instead of raising AttributeError (same guard as unsubscribe).
        research_id = (
            data.get("research_id") if isinstance(data, dict) else None
        )
        if not research_id:
            return

        # Verify the connected user actually owns this research before
        # subscribing. The in-memory `_active_research` snapshot is keyed
        # only by research_id (no user tuple), so without this guard any
        # logged-in user could subscribe to any guessed/leaked research
        # UUID and receive its progress events. The per-user encrypted DB
        # is the ownership boundary: if the research row doesn't exist in
        # the user's DB, they don't own it.
        username = session.get("username")
        if not username or not self._user_owns_research(username, research_id):
            self.__log_info(
                f"Rejected subscribe from {request.sid}: user does not own research {research_id}"
            )
            return

        with self.__lock:
            if research_id not in self.__socket_subscriptions:
                self.__socket_subscriptions[research_id] = set()
            self.__socket_subscriptions[research_id].add(request.sid)
        self.__log_info(
            f"Client {request.sid} subscribed to research {research_id}"
        )

        # Send current status immediately if available in active research
        snapshot = get_active_research_snapshot(research_id)
        if snapshot is not None:
            progress = snapshot["progress"]
            latest_log = snapshot["log"][-1] if snapshot["log"] else None

            if latest_log:
                self.emit_socket_event(
                    f"progress_{research_id}",
                    {
                        "progress": progress,
                        "message": latest_log.get("message", "Processing..."),
                        "status": ResearchStatus.IN_PROGRESS,
                        "log_entry": latest_log,
                    },
                    room=request.sid,
                )

    @staticmethod
    def _user_owns_research(username: str, research_id: str) -> bool:
        """Return True if the given user owns this research / benchmark id.

        Used as the authorization boundary for WebSocket subscriptions —
        ownership is checked against the user's encrypted SQLite database,
        which is the per-user data partition. A static helper so unit
        tests can exercise the authz logic without standing up the
        singleton/Flask app.

        Recognizes both normal research (``ResearchHistory``, UUID id) and
        benchmark runs (``BenchmarkRun``, integer id) — the benchmark page
        subscribes with its ``BenchmarkRun.id``, which lives in the same
        per-user DB. Both checks stay scoped to the caller's own database,
        so no cross-user access is introduced.

        Any exception raised while opening or querying the database is
        logged as a warning and treated as "not owned" (returns False).
        """
        try:
            from ...database.session_context import get_user_db_session
            from ...database.models import ResearchHistory

            with get_user_db_session(username) as db:
                if (
                    db.query(ResearchHistory.id)
                    .filter(ResearchHistory.id == research_id)
                    .first()
                    is not None
                ):
                    return True

                # Benchmark pages subscribe with their BenchmarkRun.id.
                # Recognize the user's own benchmark runs so the ownership
                # gate doesn't drop benchmark live progress (regression vs.
                # the removed cross-user broadcast). research_id stays a
                # string (never coerced to int — IDs are strings/UUIDs
                # repo-wide); SQLite applies numeric affinity to match the
                # Integer column. Only attempt this for numeric ids.
                if str(research_id).isdigit():
                    from ...database.models.benchmark import BenchmarkRun

                    return (
                        db.query(BenchmarkRun.id)
                        .filter(BenchmarkRun.id == research_id)
                        .first()
                        is not None
                    )
            return False
        except Exception:
            # Conservative: deny on any DB-open or query failure so a
            # transient infra error never silently widens authz.
            logger.opt(exception=True).warning(
                "Failed to verify research ownership for socket subscribe"
            )
            return False

    def __handle_unsubscribe(self, data, request):
        """
        Handle client unsubscribe from research updates.

        Payloads that are not a dict, or that carry no ``research_id``, are
        ignored.
        """
        research_id = (
            data.get("research_id") if isinstance(data, dict) else None
        )
        if not research_id:
            return

        # Symmetric with __handle_subscribe: require the caller to own the
        # research before mutating the per-research subscription set. The
        # practical impact of an unguarded unsubscribe is small (no data
        # exfiltration; subscribe is already guarded), but it keeps the
        # authz boundary consistent and avoids log spam from spoofed sids.
        username = session.get("username")
        if not username or not self._user_owns_research(username, research_id):
            self.__log_info(
                f"Rejected unsubscribe from {request.sid}: user does not own research {research_id}"
            )
            return

        with self.__lock:
            subs = self.__socket_subscriptions.get(research_id)
            if subs:
                subs.discard(request.sid)
                # Prune empty sets so the dict doesn't grow unbounded with
                # stale research_ids over long server runtimes.
                if not subs:
                    self.__socket_subscriptions.pop(research_id, None)
        self.__log_info(
            f"Client {request.sid} unsubscribed from research {research_id}"
        )

    def __handle_socket_error(self, e):
        """Handle Socket.IO errors by logging them; always returns False."""
        self.__log_exception(f"Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def __handle_default_error(self, e):
        """Handle unhandled Socket.IO errors by logging them; always returns False."""
        self.__log_exception(f"Unhandled Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def run(self, host: str, port: int, debug: bool = False) -> None:
        """
        Runs the SocketIO server.

        Args:
            host: The hostname to bind the server to.
            port: The port number to listen on.
            debug: Whether to run in debug mode. Defaults to False.

        """
        # Suppress Server header to prevent version information disclosure
        # This must be done before starting the server because Werkzeug adds
        # the header at the HTTP layer, not WSGI layer
        try:
            from werkzeug.serving import WSGIRequestHandler

            WSGIRequestHandler.version_string = lambda self: ""  # type: ignore[method-assign]
            logger.debug("Suppressed Server header for security")
        except ImportError:
            logger.warning(
                "Could not suppress Server header - werkzeug not found"
            )

        logger.info(f"Starting web server on {host}:{port} (debug: {debug})")
        self.__socketio.run(
            self.__app,  # Use the stored Flask app reference
            debug=debug,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
            use_reloader=False,
        )