Coverage for src / local_deep_research / web / services / socket_service.py: 80%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1from threading import Lock 

2from typing import Any, NoReturn 

3 

4from flask import Flask, request 

5from flask_socketio import SocketIO 

6from loguru import logger 

7 

8from ...constants import ResearchStatus 

9from ..routes.globals import get_globals 

10 

11 

class SocketIOService:
    """
    Singleton class for managing SocketIO connections and subscriptions.

    The single instance owns the ``SocketIO`` server object, the mapping of
    research IDs to subscribed client session IDs, and a lock protecting
    that mapping.
    """

    _instance = None

    def __new__(cls, *args: Any, app: Flask | None = None, **kwargs: Any):
        """
        Return the singleton instance, creating it on the first call.

        Args:
            app: The Flask app to bind this service to. It must be specified
                the first time this is called and the singleton instance is
                created, but will be ignored after that.
            *args: Arguments to pass to the superclass's __new__ method.
            **kwargs: Keyword arguments to pass to the superclass's __new__ method.

        Raises:
            ValueError: If no instance exists yet and `app` is None.
        """
        if cls._instance is None:
            if app is None:
                raise ValueError(
                    "Flask app must be specified to create a SocketIOService instance."
                )
            instance = super(SocketIOService, cls).__new__(
                cls, *args, **kwargs
            )
            instance.__init_singleton(app)
            # Publish the instance only after it has been fully initialized,
            # so a failure above never leaves a half-built singleton behind.
            cls._instance = instance
        return cls._instance

    @staticmethod
    def __parse_cors_origins(raw: str | None) -> str | list[str] | None:
        """
        Translate the raw allowed-origins setting into a CORS policy value.

        Args:
            raw: The configured value, or None when nothing is configured.

        Returns:
            "*" when the setting is missing or is a (whitespace-padded) "*",
            a list of non-empty, stripped origins for a comma-separated
            value, or None when no usable origin remains. An empty list is
            never returned.
        """
        if raw is None:
            # No setting configured: preserve the existing permissive default.
            return "*"

        value = raw.strip()
        if value == "*":
            return "*"

        origins = [origin.strip() for origin in value.split(",")]
        origins = [origin for origin in origins if origin]
        return origins or None

    def __init_singleton(self, app: Flask) -> None:
        """
        Initializes the singleton instance.

        Creates the SocketIO server, sets up subscription tracking and
        registers the event and error handlers.

        Args:
            app: The app to bind this service to.

        """
        self.__app = app  # Store the Flask app reference

        # Determine WebSocket CORS policy from env var or default.
        # NOTE(review): the import is kept local as in the original,
        # presumably to avoid an import cycle - confirm.
        from ...settings.env_registry import get_env_setting

        socketio_cors = self.__parse_cors_origins(
            get_env_setting("security.websocket.allowed_origins")
        )

        if socketio_cors is None:
            logger.info(
                "Socket.IO CORS: same-origin only (set LDR_SECURITY_WEBSOCKET_ALLOWED_ORIGINS to configure)"
            )
        elif socketio_cors == "*":
            logger.debug("Socket.IO CORS: all origins allowed")
        else:
            logger.info(f"Socket.IO CORS: restricted to {socketio_cors}")

        self.__socketio = SocketIO(
            app,
            cors_allowed_origins=socketio_cors,
            async_mode="threading",
            path="/socket.io",
            logger=False,
            engineio_logger=False,
            ping_timeout=20,
            ping_interval=5,
        )

        # Socket subscription tracking: research_id -> set of client sids.
        self.__socket_subscriptions = {}
        # Set to false to disable logging in the event handlers. This can
        # be necessary because it will sometimes run the handlers directly
        # during a call to `emit` that was made in a logging handler.
        self.__logging_enabled = True
        # Protects access to shared state.
        self.__lock = Lock()

        # Register events.
        @self.__socketio.on("connect")
        def on_connect():
            self.__handle_connect(request)

        @self.__socketio.on("disconnect")
        def on_disconnect(reason: str):
            self.__handle_disconnect(request, reason)

        @self.__socketio.on("subscribe_to_research")
        def on_subscribe(data):
            globals_dict = get_globals()
            active_research = globals_dict.get("active_research", {})
            self.__handle_subscribe(data, request, active_research)

        @self.__socketio.on_error
        def on_error(e):
            return self.__handle_socket_error(e)

        @self.__socketio.on_error_default
        def on_default_error(e):
            return self.__handle_default_error(e)

    def __log_info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an info message, unless handler logging is disabled."""
        if self.__logging_enabled:
            logger.info(message, *args, **kwargs)

    def __log_error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an error message, unless handler logging is disabled."""
        if self.__logging_enabled:
            logger.error(message, *args, **kwargs)

    def __log_exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Log an exception, unless handler logging is disabled."""
        if self.__logging_enabled:
            logger.exception(message, *args, **kwargs)

    def emit_socket_event(self, event, data, room=None):
        """
        Emit a socket event to clients.

        Args:
            event: The event name to emit
            data: The data to send with the event
            room: Optional room ID to send to specific client

        Returns:
            bool: True if emission was successful, False otherwise
        """
        try:
            if room:
                # A room was given, so only emit to that room.
                self.__socketio.emit(event, data, room=room)
            else:
                # Otherwise broadcast to all.
                self.__socketio.emit(event, data)
            return True
        except Exception:
            logger.exception(f"Error emitting socket event {event}")
            return False

    def emit_to_subscribers(
        self, event_base, research_id, data, enable_logging: bool = True
    ):
        """
        Emit an event to all subscribers of a specific research.

        The event is first emitted without a room and then once more to
        each subscribed client's own room.

        Args:
            event_base: Base event name (will be formatted with research_id)
            research_id: ID of the research
            data: The data to send with the event
            enable_logging: If set to false, this will disable all logging,
                which is useful if we are calling this inside of a logging
                handler.

        Returns:
            bool: True if emission was successful, False otherwise

        """
        if not enable_logging:
            self.__logging_enabled = False

        try:
            # Emit to the general channel for the research.
            full_event = f"{event_base}_{research_id}"
            self.__socketio.emit(full_event, data)

            # Snapshot the subscriber set under the lock so that iterating
            # it cannot race with concurrent subscribe/disconnect handling.
            with self.__lock:
                subscriptions = self.__socket_subscriptions.get(research_id)
                if subscriptions is not None:
                    subscriptions = subscriptions.copy()

            if subscriptions is not None:
                for sid in subscriptions:
                    try:
                        self.__socketio.emit(full_event, data, room=sid)
                    except Exception:
                        self.__log_exception(
                            f"Error emitting to subscriber {sid}"
                        )

            return True
        except Exception:
            self.__log_exception(
                f"Error emitting to subscribers for research {research_id}"
            )
            return False
        finally:
            # Always re-enable handler logging once the emit is finished.
            self.__logging_enabled = True

    def __handle_connect(self, request):
        """Handle client connection"""
        self.__log_info(f"Client connected: {request.sid}")

    def __handle_disconnect(self, request, reason: str):
        """
        Handle client disconnection.

        Removes the client from every subscription set, drops research
        entries that end up empty, and then attempts to clean up the
        current thread's database session.

        Args:
            request: The request object; its `sid` identifies the client.
            reason: The disconnect reason passed to the event handler.
        """
        try:
            self.__log_info(
                f"Client {request.sid} disconnected because: {reason}"
            )
            # __socket_subscriptions is keyed by research_id -> set of sids,
            # so every entry is visited and the disconnecting sid discarded.
            with self.__lock:
                for sids in self.__socket_subscriptions.values():
                    sids.discard(request.sid)
                empty_keys = [
                    research_id
                    for research_id, sids in self.__socket_subscriptions.items()
                    if not sids
                ]
                for key in empty_keys:
                    del self.__socket_subscriptions[key]
            self.__log_info(f"Removed subscription for client {request.sid}")

            # Clean up any thread-local database sessions that may have been
            # created during socket handler execution. This prevents file
            # descriptor leaks from unclosed SQLAlchemy sessions.
            try:
                from ...database.thread_local_session import (
                    cleanup_current_thread,
                )

                cleanup_current_thread()
            except ImportError:
                pass  # Module not available, skip cleanup
            except Exception:
                self.__log_exception(
                    "Error cleaning up thread session on disconnect"
                )
        except Exception as e:
            self.__log_exception(f"Error handling disconnect: {e}")

    def __handle_subscribe(self, data, request, active_research=None):
        """
        Handle client subscription to research updates.

        Args:
            data: The client-supplied event payload; expected to be a dict
                with a "research_id" entry. Malformed payloads are logged
                and ignored.
            request: The request object; its `sid` identifies the client.
            active_research: Optional mapping of research_id to its current
                state. When it has an entry with at least one log item, the
                latest progress is sent to the subscribing client at once.
        """
        # The payload comes straight from the client, so validate its shape
        # before touching it.
        if not isinstance(data, dict):
            self.__log_error(
                f"Ignoring subscribe request from client {request.sid}: "
                f"payload must be an object, got {type(data).__name__}"
            )
            return

        research_id = data.get("research_id")
        if not research_id:
            return

        try:
            hash(research_id)
        except TypeError:
            # Lists and dicts cannot be used as subscription keys.
            self.__log_error(
                f"Ignoring subscribe request from client {request.sid}: "
                f"research_id has unsupported type {type(research_id).__name__}"
            )
            return

        with self.__lock:
            # Initialize the subscription set if needed, then add the client.
            self.__socket_subscriptions.setdefault(research_id, set()).add(
                request.sid
            )
        self.__log_info(
            f"Client {request.sid} subscribed to research {research_id}"
        )

        # Send current status immediately if available in active research.
        if active_research and research_id in active_research:
            entry = active_research[research_id]
            progress = entry["progress"]
            latest_log = entry["log"][-1] if entry["log"] else None

            if latest_log:
                self.emit_socket_event(
                    f"research_progress_{research_id}",
                    {
                        "progress": progress,
                        "message": latest_log.get("message", "Processing..."),
                        "status": ResearchStatus.IN_PROGRESS,
                        "log_entry": latest_log,
                    },
                    room=request.sid,
                )

    def __handle_socket_error(self, e):
        """Handle Socket.IO errors"""
        self.__log_exception(f"Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def __handle_default_error(self, e):
        """Handle unhandled Socket.IO errors"""
        self.__log_exception(f"Unhandled Socket.IO error: {str(e)}")
        # Don't propagate exceptions to avoid crashing the server
        return False

    def run(self, host: str, port: int, debug: bool = False) -> NoReturn:
        """
        Runs the SocketIO server.

        Args:
            host: The hostname to bind the server to.
            port: The port number to listen on.
            debug: Whether to run in debug mode. Defaults to False.

        """
        # Suppress Server header to prevent version information disclosure
        # This must be done before starting the server because Werkzeug adds
        # the header at the HTTP layer, not WSGI layer
        try:
            from werkzeug.serving import WSGIRequestHandler

            WSGIRequestHandler.version_string = lambda self: ""
            logger.debug("Suppressed Server header for security")
        except ImportError:
            logger.warning(
                "Could not suppress Server header - werkzeug not found"
            )

        logger.info(f"Starting web server on {host}:{port} (debug: {debug})")
        self.__socketio.run(
            self.__app,  # Use the stored Flask app reference
            debug=debug,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
            use_reloader=False,
        )

327 use_reloader=False, 

328 )