Coverage for src/local_deep_research/citation_handlers/base_citation_handler.py: 96%

79 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1""" 

2Base class for all citation handlers. 

3""" 

4 

5from abc import ABC, abstractmethod 

6from typing import Any, Callable, Dict, List, Optional, Union 

7 

8from langchain_core.documents import Document 

9from loguru import logger 

10 

11from ..utilities.json_utils import get_llm_response_text 

12 

13 

class BaseCitationHandler(ABC):
    """Abstract base class for citation handlers.

    Holds the LLM and a settings snapshot shared by concrete handlers and
    provides helpers for invoking the LLM (optionally streaming tokens to a
    callback), reading settings, turning search results into LangChain
    ``Document`` objects, and formatting numbered sources for citation.
    Subclasses implement :meth:`analyze_initial` and
    :meth:`analyze_followup`.
    """

    # Lower-cased string values accepted as "true" for boolean settings that
    # arrive as text rather than as real booleans.
    _TRUTHY_STRINGS = frozenset({"true", "1", "yes", "on"})

    def __init__(self, llm, settings_snapshot=None):
        """Create a handler.

        Args:
            llm: Language model used for analysis. Must provide ``invoke()``;
                ``stream()`` is used when present and a stream callback is set.
            settings_snapshot: Mapping of setting keys to either raw values or
                ``{"value": ...}`` dicts. Defaults to an empty dict.
        """
        self.llm = llm
        self.settings_snapshot = settings_snapshot or {}
        # Ensures the fact-checking state is logged only once per handler.
        self._fact_checking_logged = False
        self.stream_callback: Optional[Callable[[str], None]] = None

    def set_stream_callback(
        self, callback: Optional[Callable[[str], None]]
    ) -> None:
        """Set a callback that receives each streamed LLM token.

        Passing ``None`` clears the callback, which makes
        :meth:`_invoke_with_streaming` use a plain ``invoke()`` call.
        """
        self.stream_callback = callback

    @staticmethod
    def _chunk_text(chunk: Any) -> str:
        """Return the text carried by one streamed chunk as a plain string.

        Handles raw strings and message chunks whose ``content`` is either a
        string or a list of content blocks (strings or dicts with a ``"text"``
        entry). Blocks with a non-text ``"type"`` are ignored. Objects without
        a ``content`` attribute fall back to ``str(chunk)``; ``content=None``
        yields ``""``.
        """
        if isinstance(chunk, str):
            return chunk
        if not hasattr(chunk, "content"):
            return str(chunk)
        content = chunk.content
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Content-block lists must be flattened: appending the list itself
            # would make the later "".join() raise TypeError.
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type", "text") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)

    def _invoke_with_streaming(self, prompt: str) -> str:
        """
        Invoke the LLM, streaming tokens through the callback if set.

        Falls back to a single ``invoke()`` call when no callback is
        registered, when the LLM does not support ``.stream()``, or when the
        stream fails before emitting any text. If the stream fails after
        some text was emitted, the partial text is returned instead.

        Args:
            prompt: Prompt sent to the LLM.

        Returns:
            The complete response text, normalized by
            ``get_llm_response_text``.
        """
        if self.stream_callback and hasattr(self.llm, "stream"):
            chunks: List[str] = []
            try:
                for chunk in self.llm.stream(prompt):
                    text = self._chunk_text(chunk)
                    if not text:
                        continue
                    chunks.append(text)
                    try:
                        self.stream_callback(text)
                    except Exception:
                        # Non-critical: a broken callback must not break
                        # synthesis.
                        logger.debug("stream_callback failed")
            except Exception:
                # If any chunks already crossed the wire to the client,
                # restarting via .invoke() would (a) double-bill the LLM
                # and (b) cause the frontend's accumulated streamed text
                # to diverge from the new full response — the chat bubble
                # then shows partial chunks while the DB row carries the
                # invoke()-result. Only fall back when nothing was emitted.
                if chunks:
                    logger.warning(
                        "Stream errored after {} chunks; returning partial "
                        "content (no invoke() fallback to avoid double-bill "
                        "and UI/DB divergence)",
                        len(chunks),
                    )
                    return get_llm_response_text("".join(chunks))
                logger.debug(
                    "Stream failed before any chunk; falling back to invoke()"
                )
            else:
                # Normalize the joined chunks exactly like the invoke() path
                # below: .stream() bypasses ProcessingLLMWrapper.invoke (which
                # is the only place <think> blocks are stripped), so without
                # this a reasoning model's <think>…</think> would leak into the
                # persisted answer. Kept outside the try so a normalization
                # error is not misreported as a stream failure.
                return get_llm_response_text("".join(chunks))

        # No callback (non-chat research) or stream unavailable: delegate to
        # the same normalization _invoke_text uses, so <think> blocks
        # are stripped and str/object responses are handled uniformly.
        return get_llm_response_text(self.llm.invoke(prompt))

    def _invoke_text(self, prompt: str) -> str:
        """Invoke the LLM and return normalized text.

        Handles both message objects (``.content``) and raw string responses,
        and strips ``<think>`` reasoning blocks via ``get_llm_response_text``.
        """
        return get_llm_response_text(self.llm.invoke(prompt))

    def get_setting(self, key: str, default=None):
        """Get a setting value from the snapshot.

        Snapshot entries may be raw values or dicts of the form
        ``{"value": ...}``; the latter are unwrapped. Returns ``default`` when
        ``key`` is absent.
        """
        if key not in self.settings_snapshot:
            return default
        value = self.settings_snapshot[key]
        if isinstance(value, dict) and "value" in value:
            return value["value"]
        return value

    def is_fact_checking_enabled(self) -> bool:
        """Check if fact-checking is enabled and log the state once.

        Reads ``general.enable_fact_checking``. String values are parsed
        case-insensitively ("true", "1", "yes", "on" mean enabled), so a
        textual "false" is not treated as truthy.
        """
        raw = self.get_setting("general.enable_fact_checking", False)
        if isinstance(raw, str):
            enabled = raw.strip().lower() in self._TRUTHY_STRINGS
        else:
            enabled = bool(raw)

        if not self._fact_checking_logged:
            handler_name = type(self).__name__
            if enabled:
                logger.info(
                    f"[{handler_name}] Fact-checking is ENABLED — "
                    "extra LLM call per synthesis"
                )
            else:
                logger.info(f"[{handler_name}] Fact-checking is DISABLED")
            self._fact_checking_logged = True
        return enabled

    def _get_output_instruction_prefix(self) -> str:
        """
        Get formatted output instructions from settings if present.

        This allows users to customize output language, tone, style, and formatting
        for research answers and reports. Instructions are prepended to prompts
        sent to the LLM. A missing, ``None``, non-string, or blank setting
        produces no prefix.

        Returns:
            str: Formatted instruction prefix if custom instructions are set,
                empty string otherwise.

        Examples:
            - "Respond in Spanish with formal academic tone"
            - "Use simple language suitable for beginners"
            - "Be concise with bullet points"
        """
        raw = self.get_setting("general.output_instructions", "")
        # Guard against None / non-string snapshot values, which would make
        # .strip() raise AttributeError.
        output_instructions = raw.strip() if isinstance(raw, str) else ""

        if output_instructions:
            return f"User-Specified Output Style: {output_instructions}\n\n"
        return ""

    def _create_documents(
        self, search_results: Union[str, List[Dict]], nr_of_links: int = 0
    ) -> List[Document]:
        """
        Convert search results to LangChain documents format and add index
        to original search results.

        Each dict result lacking an ``"index"`` key is given
        ``str(position + nr_of_links + 1)`` in place, so the caller's results
        carry the same citation numbers as the returned documents. Existing
        indices are preserved (e.g. for topic organization).

        Args:
            search_results: List of result dicts. A string, ``None`` or an
                empty list yields no documents.
            nr_of_links: Offset added to newly assigned indices so numbering
                continues after previously cited sources.

        Returns:
            One ``Document`` per dict result; non-dict entries are skipped.
            Page content is ``full_content``, falling back to ``snippet`` when
            that is missing or empty.
        """
        documents: List[Document] = []
        if not search_results or isinstance(search_results, str):
            return documents

        for i, result in enumerate(search_results):
            if not isinstance(result, dict):
                continue
            # Mutates the caller's dict on purpose: downstream code reads the
            # citation index back from the original search results.
            result.setdefault("index", str(i + nr_of_links + 1))
            # `or` chain so None/empty full_content does not reach Document
            # (which requires a string) or hide an available snippet.
            content = (
                result.get("full_content") or result.get("snippet") or ""
            )
            documents.append(
                Document(
                    page_content=content,
                    metadata={
                        "source": result.get("link", f"source_{i + 1}"),
                        "title": result.get("title", f"Source {i + 1}"),
                        "index": int(result["index"]),
                    },
                )
            )
        return documents

    def _format_sources(self, documents: List[Document]) -> str:
        """Format sources as ``[index] content`` blocks separated by blank lines."""
        return "\n\n".join(
            f"[{doc.metadata['index']}] {doc.page_content}" for doc in documents
        )

    @abstractmethod
    def analyze_initial(
        self, query: str, search_results: Union[str, List[Dict]]
    ) -> Dict[str, Any]:
        """Process initial analysis with citations."""

    @abstractmethod
    def analyze_followup(
        self,
        question: str,
        search_results: Union[str, List[Dict]],
        previous_knowledge: str,
        nr_of_links: int,
    ) -> Dict[str, Any]:
        """Process follow-up analysis with citations."""