Coverage for src/local_deep_research/citation_handlers/base_citation_handler.py: 96%
79 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1"""
2Base class for all citation handlers.
3"""
5from abc import ABC, abstractmethod
6from typing import Any, Callable, Dict, List, Optional, Union
8from langchain_core.documents import Document
9from loguru import logger
11from ..utilities.json_utils import get_llm_response_text
class BaseCitationHandler(ABC):
    """Abstract base class for citation handlers.

    Stores the LLM and a settings snapshot and provides helpers shared by
    concrete handlers: LLM invocation (optionally streamed), settings lookup,
    conversion of search results into LangChain documents, and source
    formatting. Subclasses must implement ``analyze_initial`` and
    ``analyze_followup``.
    """

    def __init__(self, llm, settings_snapshot=None):
        """Store the LLM and settings snapshot.

        Args:
            llm: Object exposing ``invoke(prompt)`` and, optionally,
                ``stream(prompt)``.
            settings_snapshot: Mapping of setting keys to values (or to
                ``{"value": ...}`` dicts). A falsy value is replaced by an
                empty dict.
        """
        self.llm = llm
        self.settings_snapshot = settings_snapshot or {}
        # Ensures the fact-checking state is logged only once per handler.
        self._fact_checking_logged = False
        self.stream_callback: Optional[Callable[[str], None]] = None

    def set_stream_callback(self, callback: Callable[[str], None]):
        """Set a callback that receives each streamed LLM token."""
        self.stream_callback = callback

    def _invoke_with_streaming(self, prompt: str) -> str:
        """
        Invoke the LLM, streaming tokens through the callback if set.

        Falls back to a single ``invoke()`` call when no callback is
        registered, when the LLM does not support ``.stream()``, or when the
        stream fails before emitting any chunk.

        Returns:
            The complete response text (partial text if the stream errored
            after at least one chunk was emitted).
        """
        if self.stream_callback and hasattr(self.llm, "stream"):
            chunks = []
            try:
                for chunk in self.llm.stream(prompt):
                    text = (
                        chunk
                        if isinstance(chunk, str)
                        else getattr(chunk, "content", str(chunk))
                    )
                    if text:
                        chunks.append(text)
                        try:
                            self.stream_callback(text)
                        except Exception:
                            logger.debug(
                                "stream_callback failed"
                            )  # Non-critical: don't break synthesis
                # Normalize the joined chunks exactly like the invoke() path
                # below: .stream() bypasses ProcessingLLMWrapper.invoke (which
                # is the only place <think> blocks are stripped), so without
                # this a reasoning model's <think>…</think> would leak into the
                # persisted answer. The live token stream is a separate concern.
                return get_llm_response_text("".join(chunks))
            except Exception:
                # If any chunks already crossed the wire to the client,
                # restarting via .invoke() would (a) double-bill the LLM
                # and (b) cause the frontend's accumulated streamed text
                # to diverge from the new full response — the chat bubble
                # then shows partial chunks while the DB row carries the
                # invoke()-result. Only fall back when nothing was emitted.
                if chunks:
                    logger.warning(
                        "Stream errored after {} chunks; returning partial "
                        "content (no invoke() fallback to avoid double-bill "
                        "and UI/DB divergence)",
                        len(chunks),
                    )
                    return get_llm_response_text("".join(chunks))
                logger.debug(
                    "Stream failed before any chunk; falling back to invoke()"
                )

        # No callback (non-chat research) or stream unavailable: delegate to
        # the same normalization _invoke_text uses, so <think> blocks
        # are stripped and str/object responses are handled uniformly.
        return get_llm_response_text(self.llm.invoke(prompt))

    def _invoke_text(self, prompt: str) -> str:
        """Invoke the LLM and return normalized text.

        Handles both message objects (``.content``) and raw string responses,
        and strips ``<think>`` reasoning blocks via ``get_llm_response_text``.
        """
        return get_llm_response_text(self.llm.invoke(prompt))

    def get_setting(self, key: str, default=None):
        """Get a setting value from the snapshot.

        Args:
            key: Setting key to look up in ``settings_snapshot``.
            default: Value returned when the key is absent.

        Returns:
            The stored value, unwrapped from a ``{"value": ...}`` dict when
            the snapshot stores it in that form; ``default`` if the key is
            missing.
        """
        if key in self.settings_snapshot:
            value = self.settings_snapshot[key]
            # Extract value from dict structure if needed
            if isinstance(value, dict) and "value" in value:
                return value["value"]
            return value
        return default

    def is_fact_checking_enabled(self) -> bool:
        """Check if fact-checking is enabled and log the state once."""
        enabled = self.get_setting("general.enable_fact_checking", False)
        if not self._fact_checking_logged:
            handler_name = type(self).__name__
            if enabled:
                logger.info(
                    f"[{handler_name}] Fact-checking is ENABLED — "
                    f"extra LLM call per synthesis"
                )
            else:
                logger.info(f"[{handler_name}] Fact-checking is DISABLED")
            self._fact_checking_logged = True
        return bool(enabled)

    def _get_output_instruction_prefix(self) -> str:
        """
        Get formatted output instructions from settings if present.

        This allows users to customize output language, tone, style, and formatting
        for research answers and reports. Instructions are prepended to prompts
        sent to the LLM.

        Returns:
            str: Formatted instruction prefix if custom instructions are set,
                 empty string otherwise (including when the setting is
                 ``None`` or only whitespace).

        Examples:
            - "Respond in Spanish with formal academic tone"
            - "Use simple language suitable for beginners"
            - "Be concise with bullet points"
        """
        raw_instructions = self.get_setting("general.output_instructions", "")
        # The key can be present with an explicit None (or another non-string
        # value); calling .strip() on it directly would raise AttributeError.
        if raw_instructions is None:
            return ""
        output_instructions = str(raw_instructions).strip()

        if output_instructions:
            return f"User-Specified Output Style: {output_instructions}\n\n"
        return ""

    def _create_documents(
        self, search_results: Union[str, List[Dict]], nr_of_links: int = 0
    ) -> List["Document"]:
        """
        Convert search results to LangChain documents format and add index
        to original search results.

        Args:
            search_results: List of result dicts; a plain string yields no
                documents. Non-dict list entries are skipped.
            nr_of_links: Offset added to the 1-based position when assigning
                an index to results that do not already carry one.

        Returns:
            One ``Document`` per dict result, with ``source``, ``title`` and
            integer ``index`` metadata.

        Note:
            Each dict lacking an ``"index"`` key is mutated in place to
            receive one (stored as a string).
        """
        documents: List["Document"] = []
        if isinstance(search_results, str):
            return documents

        for i, result in enumerate(search_results):
            if not isinstance(result, dict):
                continue

            # Add index to the original search result dictionary if it doesn't exist
            # This preserves indices that were already set (e.g., for topic organization)
            if "index" not in result:
                result["index"] = str(i + nr_of_links + 1)

            # Fall back to the snippet (then to "") when full_content is
            # missing, None or empty, so page_content is never None.
            content = result.get("full_content") or result.get("snippet") or ""

            documents.append(
                Document(
                    page_content=content,
                    metadata={
                        "source": result.get("link", f"source_{i + 1}"),
                        "title": result.get("title", f"Source {i + 1}"),
                        # "index" is guaranteed to be present at this point.
                        "index": int(result["index"]),
                    },
                )
            )
        return documents

    def _format_sources(self, documents: List["Document"]) -> str:
        """Format sources with numbers for citation.

        Each document becomes ``[<index>] <page_content>``; entries are
        separated by a blank line.
        """
        return "\n\n".join(
            f"[{doc.metadata['index']}] {doc.page_content}" for doc in documents
        )

    @abstractmethod
    def analyze_initial(
        self, query: str, search_results: Union[str, List[Dict]]
    ) -> Dict[str, Any]:
        """Process initial analysis with citations."""
        pass

    @abstractmethod
    def analyze_followup(
        self,
        question: str,
        search_results: Union[str, List[Dict]],
        previous_knowledge: str,
        nr_of_links: int,
    ) -> Dict[str, Any]:
        """Process follow-up analysis with citations."""
        pass