Coverage for src / local_deep_research / web / services / socket_service.py: 97%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1from threading import Lock 

2from typing import Any 

3 

4from flask import Flask, request 

5from flask_socketio import SocketIO 

6from loguru import logger 

7 

8from ...constants import ResearchStatus 

9from ..routes.globals import get_active_research_snapshot 

10 

11 

12class SocketIOService: 

13 """ 

14 Singleton class for managing SocketIO connections and subscriptions. 

15 """ 

16 

17 _instance = None 

18 

19 def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any): 

20 """ 

21 Args: 

22 app: The Flask app to bind this service to. It must be specified 

23 the first time this is called and the singleton instance is 

24 created, but will be ignored after that. 

25 *args: Arguments to pass to the superclass's __new__ method. 

26 **kwargs: Keyword arguments to pass to the superclass's __new__ method. 

27 """ 

28 if not cls._instance: 

29 if app is None: 

30 raise ValueError( 

31 "Flask app must be specified to create a SocketIOService instance." 

32 ) 

33 cls._instance = super(SocketIOService, cls).__new__( 

34 cls, *args, **kwargs 

35 ) 

36 cls._instance.__init_singleton(app) 

37 return cls._instance 

38 

39 def __init_singleton(self, app: Flask) -> None: 

40 """ 

41 Initializes the singleton instance. 

42 

43 Args: 

44 app: The app to bind this service to. 

45 

46 """ 

47 self.__app = app # Store the Flask app reference 

48 

49 # Determine WebSocket CORS policy from env var or default 

50 from ...settings.env_registry import get_env_setting 

51 

52 ws_origins_env = get_env_setting("security.websocket.allowed_origins") 

53 socketio_cors: str | list[str] | None 

54 if ws_origins_env is not None: 

55 if ws_origins_env == "*": 

56 socketio_cors = "*" 

57 elif ws_origins_env: 

58 socketio_cors = [o.strip() for o in ws_origins_env.split(",")] 

59 else: 

60 socketio_cors = None 

61 else: 

62 # No env var set — preserve existing permissive default 

63 socketio_cors = "*" 

64 

65 if socketio_cors is None: 

66 logger.info( 

67 "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)" 

68 ) 

69 elif socketio_cors == "*": 

70 logger.debug("Socket.IO CORS: all origins allowed") 

71 else: 

72 logger.info(f"Socket.IO CORS: restricted to {socketio_cors}") 

73 

74 self.__socketio = SocketIO( 

75 app, 

76 cors_allowed_origins=socketio_cors, 

77 async_mode="threading", 

78 path="/socket.io", 

79 logger=False, 

80 engineio_logger=False, 

81 ping_timeout=20, 

82 ping_interval=5, 

83 ) 

84 

85 # Socket subscription tracking. 

86 self.__socket_subscriptions: dict[str, Any] = {} 

87 # Set to false to disable logging in the event handlers. This can 

88 # be necessary because it will sometimes run the handlers directly 

89 # during a call to `emit` that was made in a logging handler. 

90 self.__logging_enabled = True 

91 # Protects access to shared state. 

92 self.__lock = Lock() 

93 

94 # Register events. 

95 @self.__socketio.on("connect") 

96 def on_connect(): 

97 self.__handle_connect(request) 

98 

99 @self.__socketio.on("disconnect") 

100 def on_disconnect(reason: str): 

101 self.__handle_disconnect(request, reason) 

102 

103 @self.__socketio.on("subscribe_to_research") 

104 def on_subscribe(data): 

105 self.__handle_subscribe(data, request) 

106 

107 @self.__socketio.on_error 

108 def on_error(e): 

109 return self.__handle_socket_error(e) 

110 

111 @self.__socketio.on_error_default 

112 def on_default_error(e): 

113 return self.__handle_default_error(e) 

114 

115 def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None: 

116 """Log an info message.""" 

117 if self.__logging_enabled: 

118 logger.info(message, *args, **kwargs) 

119 

120 def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None: 

121 """Log an error message.""" 

122 if self.__logging_enabled: 

123 logger.error(message, *args, **kwargs) 

124 

125 def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None: 

126 """Log an exception.""" 

127 if self.__logging_enabled: 

128 logger.exception(message, *args, **kwargs) 

129 

130 def emit_socket_event(self, event, data, room=None): 

131 """ 

132 Emit a socket event to clients. 

133 

134 Args: 

135 event: The event name to emit 

136 data: The data to send with the event 

137 room: Optional room ID to send to specific client 

138 

139 Returns: 

140 bool: True if emission was successful, False otherwise 

141 """ 

142 try: 

143 # If room is specified, only emit to that room 

144 if room: 

145 self.__socketio.emit(event, data, room=room) 

146 else: 

147 # Otherwise broadcast to all 

148 self.__socketio.emit(event, data) 

149 return True 

150 except Exception: 

151 logger.exception(f"Error emitting socket event {event}") 

152 return False 

153 

154 def emit_to_subscribers( 

155 self, event_base, research_id, data, enable_logging: bool = True 

156 ): 

157 """ 

158 Emit an event to all subscribers of a specific research. 

159 

160 Args: 

161 event_base: Base event name (will be formatted with research_id) 

162 research_id: ID of the research 

163 data: The data to send with the event 

164 enable_logging: If set to false, this will disable all logging, 

165 which is useful if we are calling this inside of a logging 

166 handler. 

167 

168 Returns: 

169 bool: True if emission was successful, False otherwise 

170 

171 """ 

172 if not enable_logging: 

173 self.__logging_enabled = False 

174 

175 try: 

176 full_event = f"{event_base}_{research_id}" 

177 

178 # Emit only to specific subscribers (no broadcast) to avoid 

179 # duplicate messages and reduce server load under concurrency 

180 with self.__lock: 

181 subscriptions = self.__socket_subscriptions.get(research_id) 

182 if subscriptions: 

183 subscriptions = ( 

184 subscriptions.copy() 

185 ) # snapshot avoids RuntimeError 

186 else: 

187 subscriptions = None 

188 if subscriptions is not None: 

189 for sid in subscriptions: 

190 try: 

191 self.__socketio.emit(full_event, data, room=sid) 

192 except Exception: 

193 self.__log_exception( 

194 f"Error emitting to subscriber {sid}" 

195 ) 

196 else: 

197 # No targeted subscribers yet — broadcast so early 

198 # listeners still receive the event 

199 self.__socketio.emit(full_event, data) 

200 

201 return True 

202 except Exception: 

203 self.__log_exception( 

204 f"Error emitting to subscribers for research {research_id}" 

205 ) 

206 return False 

207 finally: 

208 self.__logging_enabled = True 

209 

210 def remove_subscriptions_for_research(self, research_id: str) -> None: 

211 """Remove all socket subscriptions for a completed research.""" 

212 with self.__lock: 

213 removed = self.__socket_subscriptions.pop(research_id, None) 

214 if removed is not None: 

215 self.__log_info( 

216 f"Removed {len(removed)} subscription(s) for research {research_id}" 

217 ) 

218 

219 def __handle_connect(self, request): 

220 """Handle client connection""" 

221 self.__log_info(f"Client connected: {request.sid}") 

222 

223 def __handle_disconnect(self, request, reason: str): 

224 """Handle client disconnection""" 

225 try: 

226 self.__log_info( 

227 f"Client {request.sid} disconnected because: {reason}" 

228 ) 

229 # Clean up subscriptions for this client. 

230 # __socket_subscriptions is keyed by research_id → set of sids, 

231 # so we iterate all entries and discard the disconnecting sid. 

232 with self.__lock: 

233 empty_keys = [] 

234 for research_id, sids in self.__socket_subscriptions.items(): 

235 sids.discard(request.sid) 

236 if not sids: 

237 empty_keys.append(research_id) 

238 for key in empty_keys: 

239 del self.__socket_subscriptions[key] 

240 self.__log_info(f"Removed subscription for client {request.sid}") 

241 

242 # Clean up any thread-local database sessions that may have been 

243 # created during socket handler execution. This prevents file 

244 # descriptor leaks from unclosed SQLAlchemy sessions. 

245 try: 

246 from ...database.thread_local_session import ( 

247 cleanup_current_thread, 

248 ) 

249 

250 cleanup_current_thread() 

251 except ImportError: 

252 pass # Module not available, skip cleanup 

253 except Exception: 

254 self.__log_exception( 

255 "Error cleaning up thread session on disconnect" 

256 ) 

257 except Exception as e: 

258 self.__log_exception(f"Error handling disconnect: {e}") 

259 

260 def __handle_subscribe(self, data, request): 

261 """Handle client subscription to research updates""" 

262 research_id = data.get("research_id") 

263 if research_id: 

264 # Initialize subscription set if needed 

265 with self.__lock: 

266 if research_id not in self.__socket_subscriptions: 266 ↛ 270line 266 didn't jump to line 270 because the condition on line 266 was always true

267 self.__socket_subscriptions[research_id] = set() 

268 

269 # Add this client to the subscribers 

270 self.__socket_subscriptions[research_id].add(request.sid) 

271 self.__log_info( 

272 f"Client {request.sid} subscribed to research {research_id}" 

273 ) 

274 

275 # Send current status immediately if available in active research 

276 snapshot = get_active_research_snapshot(research_id) 

277 if snapshot is not None: 

278 progress = snapshot["progress"] 

279 latest_log = snapshot["log"][-1] if snapshot["log"] else None 

280 

281 if latest_log: 

282 self.emit_socket_event( 

283 f"progress_{research_id}", 

284 { 

285 "progress": progress, 

286 "message": latest_log.get( 

287 "message", "Processing..." 

288 ), 

289 "status": ResearchStatus.IN_PROGRESS, 

290 "log_entry": latest_log, 

291 }, 

292 room=request.sid, 

293 ) 

294 

295 def __handle_socket_error(self, e): 

296 """Handle Socket.IO errors""" 

297 self.__log_exception(f"Socket.IO error: {str(e)}") 

298 # Don't propagate exceptions to avoid crashing the server 

299 return False 

300 

301 def __handle_default_error(self, e): 

302 """Handle unhandled Socket.IO errors""" 

303 self.__log_exception(f"Unhandled Socket.IO error: {str(e)}") 

304 # Don't propagate exceptions to avoid crashing the server 

305 return False 

306 

307 def run(self, host: str, port: int, debug: bool = False) -> None: 

308 """ 

309 Runs the SocketIO server. 

310 

311 Args: 

312 host: The hostname to bind the server to. 

313 port: The port number to listen on. 

314 debug: Whether to run in debug mode. Defaults to False. 

315 

316 """ 

317 # Suppress Server header to prevent version information disclosure 

318 # This must be done before starting the server because Werkzeug adds 

319 # the header at the HTTP layer, not WSGI layer 

320 try: 

321 from werkzeug.serving import WSGIRequestHandler 

322 

323 WSGIRequestHandler.version_string = lambda self: "" # type: ignore[method-assign] 

324 logger.debug("Suppressed Server header for security") 

325 except ImportError: 

326 logger.warning( 

327 "Could not suppress Server header - werkzeug not found" 

328 ) 

329 

330 logger.info(f"Starting web server on {host}:{port} (debug: {debug})") 

331 self.__socketio.run( 

332 self.__app, # Use the stored Flask app reference 

333 debug=debug, 

334 host=host, 

335 port=port, 

336 allow_unsafe_werkzeug=True, 

337 use_reloader=False, 

338 )