Coverage for src / local_deep_research / web / services / socket_service.py: 80%
143 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1from threading import Lock
2from typing import Any, NoReturn
4from flask import Flask, request
5from flask_socketio import SocketIO
6from loguru import logger
8from ...constants import ResearchStatus
9from ..routes.globals import get_globals
class SocketIOService:
    """
    Singleton class for managing SocketIO connections and subscriptions.

    The single instance owns the ``SocketIO`` server object, a mapping of
    research ID -> set of subscribed client session IDs, and a lock that
    guards that mapping.
    """

    _instance = None

    def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any):
        """
        Args:
            app: The Flask app to bind this service to. It must be specified
                the first time this is called and the singleton instance is
                created, but will be ignored after that.
            *args: Arguments to pass to the superclass's __new__ method.
            **kwargs: Keyword arguments to pass to the superclass's __new__ method.

        Returns:
            The shared ``SocketIOService`` instance.

        Raises:
            ValueError: If no instance exists yet and ``app`` is None.
        """
        if cls._instance is None:
            if app is None:
                raise ValueError(
                    "Flask app must be specified to create a SocketIOService instance."
                )
            instance = super().__new__(cls, *args, **kwargs)
            instance.__init_singleton(app)
            # Publish the instance only after it is fully initialized, so a
            # failure during setup does not leave a broken singleton behind.
            cls._instance = instance
        return cls._instance

    @staticmethod
    def __parse_cors_origins(raw: str | None) -> str | list[str] | None:
        """
        Translate the raw allowed-origins setting into a CORS policy value.

        Args:
            raw: The configured value, or None when nothing is configured.

        Returns:
            ``"*"`` when the setting is unset or is a lone ``*``, a list of
            non-empty origins for a comma-separated value, or None when the
            value contains no usable origin.
        """
        if raw is None:
            # No env var set — preserve existing permissive default
            return "*"
        if raw.strip() == "*":
            return "*"
        origins = [part.strip() for part in raw.split(",")]
        origins = [origin for origin in origins if origin]
        # Never hand an empty list (or blank entries) to the server: fall
        # back to None, which the caller reports as "same-origin only".
        # NOTE(review): python-engineio documents [] as "disable CORS
        # handling", which is not what a blank setting should mean — confirm.
        return origins or None

    def __init_singleton(self, app: Flask) -> None:
        """
        Initializes the singleton instance.

        Creates the SocketIO server bound to ``app``, sets up the
        subscription-tracking state and registers the event handlers.

        Args:
            app: The app to bind this service to.

        """
        self.__app = app  # Store the Flask app reference

        # Determine WebSocket CORS policy from env var or default
        from ...settings.env_registry import get_env_setting

        socketio_cors = self.__parse_cors_origins(
            get_env_setting("security.websocket.allowed_origins")
        )

        if socketio_cors is None:
            logger.info(
                "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)"
            )
        elif socketio_cors == "*":
            logger.debug("Socket.IO CORS: all origins allowed")
        else:
            logger.info(f"Socket.IO CORS: restricted to {socketio_cors}")

        self.__socketio = SocketIO(
            app,
            cors_allowed_origins=socketio_cors,
            async_mode="threading",
            path="/socket.io",
            logger=False,
            engineio_logger=False,
            ping_timeout=20,
            ping_interval=5,
        )

        # Socket subscription tracking (research_id -> set of client sids).
        self.__socket_subscriptions = {}
        # Set to false to disable logging in the event handlers. This can
        # be necessary because it will sometimes run the handlers directly
        # during a call to `emit` that was made in a logging handler.
        self.__logging_enabled = True
        # Protects access to shared state.
        self.__lock = Lock()

        # Register events.
        @self.__socketio.on("connect")
        def on_connect():
            self.__handle_connect(request)

        @self.__socketio.on("disconnect")
        def on_disconnect(reason: str):
            self.__handle_disconnect(request, reason)

        @self.__socketio.on("subscribe_to_research")
        def on_subscribe(data):
            globals_dict = get_globals()
            active_research = globals_dict.get("active_research", {})
            self.__handle_subscribe(data, request, active_research)

        @self.__socketio.on_error
        def on_error(e):
            return self.__handle_socket_error(e)

        @self.__socketio.on_error_default
        def on_default_error(e):
            return self.__handle_default_error(e)

    def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an info message, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.info(message, *args, **kwargs)

    def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an error message, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.error(message, *args, **kwargs)

    def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an exception, unless logging is currently disabled."""
        if self.__logging_enabled:
            logger.exception(message, *args, **kwargs)

    def emit_socket_event(self, event, data, room=None):
        """
        Emit a socket event to clients.

        Args:
            event: The event name to emit
            data: The data to send with the event
            room: Optional room ID to send to specific client

        Returns:
            bool: True if emission was successful, False otherwise
        """
        try:
            # If room is specified, only emit to that room
            if room:
                self.__socketio.emit(event, data, room=room)
            else:
                # Otherwise broadcast to all
                self.__socketio.emit(event, data)
            return True
        except Exception:
            logger.exception(f"Error emitting socket event {event}")
            return False

    def emit_to_subscribers(
        self, event_base, research_id, data, enable_logging: bool = True
    ):
        """
        Emit an event to all subscribers of a specific research.

        The event is first emitted without a room, then once more to each
        subscribed client's room. A failure for one subscriber is logged
        and does not stop delivery to the others.

        Args:
            event_base: Base event name (will be formatted with research_id)
            research_id: ID of the research
            data: The data to send with the event
            enable_logging: If set to false, this will disable all logging,
                which is useful if we are calling this inside of a logging
                handler.

        Returns:
            bool: True if emission was successful, False otherwise

        """
        # Remember the current state so a nested call cannot re-enable
        # logging that an outer call deliberately switched off.
        previous_logging_enabled = self.__logging_enabled
        if not enable_logging:
            self.__logging_enabled = False

        try:
            # Emit to the general channel for the research
            full_event = f"{event_base}_{research_id}"
            self.__socketio.emit(full_event, data)

            # Emit to specific subscribers. Take a snapshot under the lock
            # so the set can change while we iterate without a RuntimeError.
            with self.__lock:
                subscribers = self.__socket_subscriptions.get(research_id)
                if subscribers is not None:
                    subscribers = subscribers.copy()

            if subscribers is not None:
                for sid in subscribers:
                    try:
                        self.__socketio.emit(full_event, data, room=sid)
                    except Exception:
                        self.__log_exception(
                            f"Error emitting to subscriber {sid}"
                        )

            return True
        except Exception:
            self.__log_exception(
                f"Error emitting to subscribers for research {research_id}"
            )
            return False
        finally:
            self.__logging_enabled = previous_logging_enabled

    def __handle_connect(self, request):
        """Handle client connection"""
        self.__log_info(f"Client connected: {request.sid}")

    def __handle_disconnect(self, request, reason: str):
        """
        Handle client disconnection.

        Removes the client's sid from every subscription set, drops sets
        that become empty, then attempts to clean up the current thread's
        database session. Errors are logged rather than raised.
        """
        try:
            self.__log_info(
                f"Client {request.sid} disconnected because: {reason}"
            )
            # Clean up subscriptions for this client.
            # __socket_subscriptions is keyed by research_id → set of sids,
            # so we iterate all entries and discard the disconnecting sid.
            with self.__lock:
                emptied = []
                for research_id, sids in self.__socket_subscriptions.items():
                    sids.discard(request.sid)
                    if not sids:
                        emptied.append(research_id)
                # Delete after the loop: the dict must not change size
                # while it is being iterated.
                for research_id in emptied:
                    del self.__socket_subscriptions[research_id]
            self.__log_info(f"Removed subscription for client {request.sid}")

            # Clean up any thread-local database sessions that may have been
            # created during socket handler execution. This prevents file
            # descriptor leaks from unclosed SQLAlchemy sessions.
            try:
                from ...database.thread_local_session import (
                    cleanup_current_thread,
                )

                cleanup_current_thread()
            except ImportError:
                pass  # Module not available, skip cleanup
            except Exception:
                self.__log_exception(
                    "Error cleaning up thread session on disconnect"
                )
        except Exception as e:
            self.__log_exception(f"Error handling disconnect: {e}")

    def __handle_subscribe(self, data, request, active_research=None):
        """
        Handle client subscription to research updates.

        Args:
            data: Payload sent by the client; expected to be a dict with a
                "research_id" key. Anything else is ignored.
            request: The current request; its ``sid`` identifies the client.
            active_research: Optional mapping of research ID to a record
                with "progress" and "log" entries. When the subscribed
                research has a latest log entry, it is sent to the client
                immediately.
        """
        # The payload comes straight from the client, so do not assume
        # its shape.
        if not isinstance(data, dict):
            self.__log_error(
                f"Ignoring malformed subscribe payload from client {request.sid}"
            )
            return

        research_id = data.get("research_id")
        if not research_id:
            return

        # Add this client to the subscribers, creating the set if needed
        with self.__lock:
            self.__socket_subscriptions.setdefault(research_id, set()).add(
                request.sid
            )
        self.__log_info(
            f"Client {request.sid} subscribed to research {research_id}"
        )

        # Send current status immediately if available in active research
        if not active_research:
            return
        try:
            # Single lookup: the record may be removed by another thread
            # between a membership test and a later index.
            record = active_research[research_id]
            progress = record["progress"]
            log_entries = record["log"]
        except KeyError:
            return

        latest_log = log_entries[-1] if log_entries else None
        if latest_log:
            self.emit_socket_event(
                f"research_progress_{research_id}",
                {
                    "progress": progress,
                    "message": latest_log.get("message", "Processing..."),
                    "status": ResearchStatus.IN_PROGRESS,
                    "log_entry": latest_log,
                },
                room=request.sid,
            )

    def __handle_socket_error(self, e):
        """Handle Socket.IO errors"""
        self.__log_exception(f"Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def __handle_default_error(self, e):
        """Handle unhandled Socket.IO errors"""
        self.__log_exception(f"Unhandled Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def run(self, host: str, port: int, debug: bool = False) -> NoReturn:
        """
        Runs the SocketIO server.

        Args:
            host: The hostname to bind the server to.
            port: The port number to listen on.
            debug: Whether to run in debug mode. Defaults to False.

        """
        # Suppress Server header to prevent version information disclosure
        # This must be done before starting the server because Werkzeug adds
        # the header at the HTTP layer, not WSGI layer
        try:
            from werkzeug.serving import WSGIRequestHandler

            WSGIRequestHandler.version_string = lambda self: ""
            logger.debug("Suppressed Server header for security")
        except ImportError:
            logger.warning(
                "Could not suppress Server header - werkzeug not found"
            )

        logger.info(f"Starting web server on {host}:{port} (debug: {debug})")
        self.__socketio.run(
            self.__app,  # Use the stored Flask app reference
            debug=debug,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
            use_reloader=False,
        )