Coverage for src / local_deep_research / web / services / socket_service.py: 97%
148 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1from threading import Lock
2from typing import Any
4from flask import Flask, request
5from flask_socketio import SocketIO
6from loguru import logger
8from ...constants import ResearchStatus
9from ..routes.globals import get_active_research_snapshot
class SocketIOService:
    """
    Singleton class for managing SocketIO connections and subscriptions.

    The service wraps a single ``SocketIO`` server bound to one Flask app,
    tracks which client session ids (sids) are subscribed to which research
    id, and offers helpers to emit events to those subscribers.
    """

    # The one shared instance; stays None until the first successful
    # construction.
    _instance = None

    def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any):
        """
        Return the shared instance, creating and initializing it on first use.

        Args:
            app: The Flask app to bind this service to. It must be specified
                the first time this is called and the singleton instance is
                created, but will be ignored after that.
            *args: Arguments to pass to the superclass's __new__ method.
            **kwargs: Keyword arguments to pass to the superclass's __new__
                method.

        Returns:
            The singleton instance.

        Raises:
            ValueError: If no instance exists yet and ``app`` is None.
        """
        if cls._instance is None:
            if app is None:
                raise ValueError(
                    "Flask app must be specified to create a SocketIOService instance."
                )
            instance = super(SocketIOService, cls).__new__(
                cls, *args, **kwargs
            )
            # Publish the instance only after initialization succeeded.
            # Otherwise an exception during setup would leave a half-built
            # singleton behind that every later call would hand out.
            instance.__init_singleton(app)
            cls._instance = instance
        return cls._instance

    @staticmethod
    def __parse_cors_origins(raw: str | None) -> str | list[str] | None:
        """
        Translate the allowed-origins setting into a SocketIO CORS value.

        Args:
            raw: The configured value, or None when the setting is unset.

        Returns:
            ``"*"`` when the setting is unset or is the wildcard, a list of
            origins for a comma-separated value, or None (same-origin only)
            when the value contains no usable origin.
        """
        if raw is None:
            # No env var set — preserve existing permissive default
            return "*"

        value = raw.strip()
        if value == "*":
            return "*"

        # Drop blank entries produced by stray or trailing commas.
        origins = [part.strip() for part in value.split(",") if part.strip()]

        # Never return an empty list: python-engineio documents `[]` as
        # "disable CORS handling", which would be the opposite of the
        # restrictive intent of an empty setting. Fall back to same-origin.
        return origins or None

    def __init_singleton(self, app: Flask) -> None:
        """
        Initializes the singleton instance.

        Creates the SocketIO server, sets up the subscription bookkeeping and
        registers the event and error handlers.

        Args:
            app: The app to bind this service to.
        """
        self.__app = app  # Store the Flask app reference

        # Determine WebSocket CORS policy from env var or default
        from ...settings.env_registry import get_env_setting

        socketio_cors: str | list[str] | None = self.__parse_cors_origins(
            get_env_setting("security.websocket.allowed_origins")
        )

        if socketio_cors is None:
            logger.info(
                "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)"
            )
        elif socketio_cors == "*":
            logger.debug("Socket.IO CORS: all origins allowed")
        else:
            logger.info(f"Socket.IO CORS: restricted to {socketio_cors}")

        self.__socketio = SocketIO(
            app,
            cors_allowed_origins=socketio_cors,
            async_mode="threading",
            path="/socket.io",
            logger=False,
            engineio_logger=False,
            ping_timeout=20,
            ping_interval=5,
        )

        # Socket subscription tracking: research_id -> set of client sids.
        self.__socket_subscriptions: dict[str, Any] = {}
        # Set to false to disable logging in the event handlers. This can
        # be necessary because it will sometimes run the handlers directly
        # during a call to `emit` that was made in a logging handler.
        self.__logging_enabled = True
        # Protects access to shared state.
        self.__lock = Lock()

        # Register events.
        @self.__socketio.on("connect")
        def on_connect():
            self.__handle_connect(request)

        @self.__socketio.on("disconnect")
        def on_disconnect(reason: str):
            self.__handle_disconnect(request, reason)

        @self.__socketio.on("subscribe_to_research")
        def on_subscribe(data):
            self.__handle_subscribe(data, request)

        # `on_error` is a decorator factory (it takes an optional namespace),
        # so it has to be called. Used bare, the handler function is consumed
        # as the namespace argument and is never registered.
        @self.__socketio.on_error()
        def on_error(e):
            return self.__handle_socket_error(e)

        # `on_error_default`, in contrast, is applied directly.
        @self.__socketio.on_error_default
        def on_default_error(e):
            return self.__handle_default_error(e)

    def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an info message, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.info(message, *args, **kwargs)

    def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an error message, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.error(message, *args, **kwargs)

    def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an exception, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.exception(message, *args, **kwargs)

    def emit_socket_event(self, event, data, room=None):
        """
        Emit a socket event to clients.

        Args:
            event: The event name to emit
            data: The data to send with the event
            room: Optional room ID to send to specific client

        Returns:
            bool: True if emission was successful, False otherwise
        """
        try:
            # If room is specified, only emit to that room
            if room:
                self.__socketio.emit(event, data, room=room)
            else:
                # Otherwise broadcast to all
                self.__socketio.emit(event, data)
            return True
        except Exception:
            logger.exception(f"Error emitting socket event {event}")
            return False

    def emit_to_subscribers(
        self, event_base, research_id, data, enable_logging: bool = True
    ):
        """
        Emit an event to all subscribers of a specific research.

        When nobody is subscribed to the research yet, the event is broadcast
        to every connected client instead.

        Args:
            event_base: Base event name (will be formatted with research_id)
            research_id: ID of the research
            data: The data to send with the event
            enable_logging: If set to false, this will disable all logging,
                which is useful if we are calling this inside of a logging
                handler.

        Returns:
            bool: True if emission was successful, False otherwise
        """
        if not enable_logging:
            self.__logging_enabled = False

        try:
            full_event = f"{event_base}_{research_id}"

            # Emit only to specific subscribers (no broadcast) to avoid
            # duplicate messages and reduce server load under concurrency
            with self.__lock:
                subscriptions = self.__socket_subscriptions.get(research_id)
                # Snapshot under the lock so the iteration below cannot be
                # disturbed by a concurrent subscribe/disconnect
                # (avoids RuntimeError).
                subscriptions = subscriptions.copy() if subscriptions else None

            if subscriptions is not None:
                for sid in subscriptions:
                    try:
                        self.__socketio.emit(full_event, data, room=sid)
                    except Exception:
                        # One failing client must not stop delivery to the
                        # remaining subscribers.
                        self.__log_exception(
                            f"Error emitting to subscriber {sid}"
                        )
            else:
                # No targeted subscribers yet — broadcast so early
                # listeners still receive the event
                self.__socketio.emit(full_event, data)

            return True
        except Exception:
            self.__log_exception(
                f"Error emitting to subscribers for research {research_id}"
            )
            return False
        finally:
            # Always re-enable logging, whatever happened above.
            self.__logging_enabled = True

    def remove_subscriptions_for_research(self, research_id: str) -> None:
        """
        Remove all socket subscriptions for a completed research.

        Args:
            research_id: ID of the research whose subscriptions are dropped.
        """
        with self.__lock:
            removed = self.__socket_subscriptions.pop(research_id, None)
        # Log after releasing the lock: a log sink may call
        # `emit_to_subscribers`, which takes the same non-reentrant lock.
        if removed is not None:
            self.__log_info(
                f"Removed {len(removed)} subscription(s) for research {research_id}"
            )

    def __handle_connect(self, request):
        """Handle client connection"""
        self.__log_info(f"Client connected: {request.sid}")

    def __handle_disconnect(self, request, reason: str):
        """
        Handle client disconnection.

        Drops the client from every subscription set and releases any
        thread-local database session. Errors are logged, never raised.

        Args:
            request: The request object of the disconnecting client; only its
                ``sid`` is used.
            reason: The disconnect reason reported for the client.
        """
        try:
            self.__log_info(
                f"Client {request.sid} disconnected because: {reason}"
            )
            # Clean up subscriptions for this client.
            # __socket_subscriptions is keyed by research_id → set of sids,
            # so we iterate all entries and discard the disconnecting sid.
            with self.__lock:
                empty_keys = []
                for research_id, sids in self.__socket_subscriptions.items():
                    sids.discard(request.sid)
                    if not sids:
                        empty_keys.append(research_id)
                # Delete in a second pass; the dict must not change size
                # while it is being iterated.
                for key in empty_keys:
                    del self.__socket_subscriptions[key]
            self.__log_info(f"Removed subscription for client {request.sid}")

            # Clean up any thread-local database sessions that may have been
            # created during socket handler execution. This prevents file
            # descriptor leaks from unclosed SQLAlchemy sessions.
            try:
                from ...database.thread_local_session import (
                    cleanup_current_thread,
                )

                cleanup_current_thread()
            except ImportError:
                pass  # Module not available, skip cleanup
            except Exception:
                self.__log_exception(
                    "Error cleaning up thread session on disconnect"
                )
        except Exception as e:
            self.__log_exception(f"Error handling disconnect: {e}")

    def __handle_subscribe(self, data, request):
        """
        Handle client subscription to research updates.

        Registers the client for the requested research and, if that research
        is currently active and has log entries, immediately sends the latest
        progress to this client only.

        Args:
            data: The payload sent by the client; expected to be a dict with
                a ``research_id`` entry.
            request: The request object of the subscribing client; only its
                ``sid`` is used.
        """
        # The payload comes straight from the client, so check its shape
        # before touching it.
        if not isinstance(data, dict):
            self.__log_error(
                f"Ignoring malformed subscribe payload from client {request.sid}"
            )
            return

        research_id = data.get("research_id")
        if not research_id:
            return

        # The id is used as a dictionary key; reject values such as lists or
        # nested objects that cannot serve as one.
        try:
            hash(research_id)
        except TypeError:
            self.__log_error(
                f"Ignoring subscribe request with invalid research_id from client {request.sid}"
            )
            return

        # Create the subscription set if needed and add this client to it.
        with self.__lock:
            self.__socket_subscriptions.setdefault(research_id, set()).add(
                request.sid
            )
        self.__log_info(
            f"Client {request.sid} subscribed to research {research_id}"
        )

        # Send current status immediately if available in active research
        snapshot = get_active_research_snapshot(research_id)
        if snapshot is None:
            return

        progress = snapshot["progress"]
        latest_log = snapshot["log"][-1] if snapshot["log"] else None

        if latest_log:
            self.emit_socket_event(
                f"progress_{research_id}",
                {
                    "progress": progress,
                    "message": latest_log.get("message", "Processing..."),
                    "status": ResearchStatus.IN_PROGRESS,
                    "log_entry": latest_log,
                },
                room=request.sid,
            )

    def __handle_socket_error(self, e):
        """Handle Socket.IO errors"""
        self.__log_exception(f"Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def __handle_default_error(self, e):
        """Handle unhandled Socket.IO errors"""
        self.__log_exception(f"Unhandled Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def run(self, host: str, port: int, debug: bool = False) -> None:
        """
        Runs the SocketIO server.

        Args:
            host: The hostname to bind the server to.
            port: The port number to listen on.
            debug: Whether to run in debug mode. Defaults to False.
        """
        # Suppress Server header to prevent version information disclosure
        # This must be done before starting the server because Werkzeug adds
        # the header at the HTTP layer, not WSGI layer
        try:
            from werkzeug.serving import WSGIRequestHandler

            WSGIRequestHandler.version_string = lambda self: ""  # type: ignore[method-assign]
            logger.debug("Suppressed Server header for security")
        except ImportError:
            logger.warning(
                "Could not suppress Server header - werkzeug not found"
            )

        logger.info(f"Starting web server on {host}:{port} (debug: {debug})")
        self.__socketio.run(
            self.__app,  # Use the stored Flask app reference
            debug=debug,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
            use_reloader=False,
        )