Coverage for src / local_deep_research / domain_classifier / classifier.py: 90%
152 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1"""Domain classifier using LLM for categorization."""
3import json
4from typing import List, Dict, Optional
5from loguru import logger
6from sqlalchemy.orm import Session
7from ..database.session_context import get_user_db_session
8from ..database.models import ResearchResource
9from .models import DomainClassification
10from ..config.llm_config import get_llm
11from ..utilities.json_utils import extract_json, get_llm_response_text
# Fixed taxonomy the LLM must choose from: main category -> allowed subcategories.
# The classifier lists these in insertion order when building its prompt.
DOMAIN_CATEGORIES = {
    "Academic & Research": [
        "University/Education", "Scientific Journal",
        "Research Institution", "Academic Database",
    ],
    "News & Media": [
        "General News", "Tech News", "Business News",
        "Entertainment News", "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia", "Technical Documentation", "API Documentation",
        "Tutorial/Guide", "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        "Q&A Platform",  # question-and-answer sites, e.g. StackOverflow, Quora
        "Blog Platform",  # hosted publishing platforms, e.g. Medium, WordPress.com
        "Personal Blog",  # blogs run by a single author
    ],
    "Business & Commerce": [
        "E-commerce", "Corporate Website", "B2B Platform",
        "Financial Service", "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development", "Cloud Service", "Open Source Project",
        "Tech Company", "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency", "Non-profit", "International Organization",
        "Think Tank", "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service", "Gaming", "Sports",
        "Arts & Culture", "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare", "Legal", "Real Estate",
        "Manufacturing", "Energy & Utilities",
    ],
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}
class DomainClassifier:
    """Classify domains using an LLM constrained to ``DOMAIN_CATEGORIES``.

    Results are persisted as ``DomainClassification`` rows in the database
    session returned by ``get_user_db_session(username)``.
    """

    # Fallbacks used when the LLM omits a field or returns an unusable value.
    _DEFAULT_CATEGORY = "Other"
    _DEFAULT_SUBCATEGORY = "Unknown"
    _DEFAULT_CONFIDENCE = 0.5

    # What may legally follow a host inside a URL: end of string, path,
    # query, fragment, or port. Used to anchor the sample-lookup LIKE patterns.
    _HOST_TERMINATORS = ("", "/%", "?%", "#%", ":%")

    def __init__(self, username: str, settings_snapshot: Optional[dict] = None):
        """Initialize the domain classifier.

        Args:
            username: Username for database session
            settings_snapshot: Settings snapshot for LLM configuration
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        self.llm = None  # created lazily by _get_llm()

    def _get_llm(self):
        """Get or initialize LLM instance (created once, then reused)."""
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards in *value*, using a backslash as escape char."""
        return (
            value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )

    def _get_domain_samples(
        self, domain: str, session: Session, limit: int = 5
    ) -> List[Dict]:
        """Get sample resources whose URL host is *domain*.

        The host is anchored right after ``//`` and must be followed by the end
        of the URL or one of ``/ ? # :``. An unanchored substring match would
        let ``t.co`` match ``reddit.com`` or ``x.com`` match ``fox.com``, and
        would feed the LLM samples from unrelated sites. The ``www.`` form of
        the domain is matched as well, because ``classify_all_domains`` strips
        that prefix.

        Args:
            domain: Domain to get samples for
            session: Database session
            limit: Maximum number of samples

        Returns:
            List of dicts with ``title``, ``url`` and ``preview`` keys
        """
        from sqlalchemy import or_

        hosts = [domain]
        if not domain.startswith("www."):
            hosts.append(f"www.{domain}")

        conditions = [
            ResearchResource.url.like(
                f"%//{self._escape_like(host)}{tail}", escape="\\"
            )
            for host in hosts
            for tail in self._HOST_TERMINATORS
        ]

        resources = (
            session.query(ResearchResource)
            .filter(or_(*conditions))
            .limit(limit)
            .all()
        )

        return [
            {
                "title": resource.title or "Untitled",
                "url": resource.url,
                "preview": resource.content_preview[:200]
                if resource.content_preview
                else None,
            }
            for resource in resources
        ]

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        Uses actual content samples (titles, previews) from the domain rather
        than relying solely on the domain name, so the classification is based
        on what the site actually publishes.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs, content previews)

        Returns:
            Formatted prompt string
        """
        categories_text = [
            f"{main_cat}: {', '.join(subcats)}"
            for main_cat, subcats in DOMAIN_CATEGORIES.items()
        ]

        samples_text = []
        for i, sample in enumerate(samples[:5], 1):
            samples_text.append(f"{i}. Title: {sample['title']}")
            if sample.get("preview"):
                samples_text.append(f"   Preview: {sample['preview'][:100]}...")

        prompt = f"""Classify the following domain into one of the predefined categories.

Domain: {domain}

Sample content from this domain:
{chr(10).join(samples_text) if samples_text else "No samples available"}

Available Categories:
{chr(10).join(categories_text)}

Respond with a JSON object containing:
- "category": The main category (e.g., "News & Media")
- "subcategory": The specific subcategory (e.g., "Tech News")
- "confidence": A confidence score between 0 and 1
- "reasoning": A brief explanation (max 100 words) of why this classification was chosen

Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.

JSON Response:"""

        return prompt

    @classmethod
    def _parse_confidence(cls, value) -> float:
        """Convert the LLM's confidence to a float clamped to [0, 1].

        Missing (None), non-numeric, or NaN values fall back to the default
        confidence instead of aborting the whole classification.
        """
        import math

        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return cls._DEFAULT_CONFIDENCE
        if math.isnan(confidence):
            return cls._DEFAULT_CONFIDENCE
        return min(max(confidence, 0.0), 1.0)

    @staticmethod
    def _text_field(result: Dict, key: str, default: str) -> str:
        """Return ``result[key]`` stripped if it is a non-blank string, else *default*."""
        value = result.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
        return default

    def _classification_fields(self, result: Dict, samples: List[Dict]) -> Dict:
        """Map a parsed LLM response plus samples onto DomainClassification columns."""
        return {
            "category": self._text_field(
                result, "category", self._DEFAULT_CATEGORY
            ),
            "subcategory": self._text_field(
                result, "subcategory", self._DEFAULT_SUBCATEGORY
            ),
            "confidence": self._parse_confidence(result.get("confidence")),
            "reasoning": self._text_field(result, "reasoning", ""),
            "sample_titles": json.dumps([s["title"] for s in samples]),
            "sample_count": len(samples),
        }

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional[DomainClassification]:
        """Classify a single domain using LLM.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if already exists

        Returns:
            DomainClassification object or None if failed (errors are logged)
        """
        try:
            with get_user_db_session(self.username) as session:
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )

                if existing is not None and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                samples = self._get_domain_samples(domain, session)
                prompt = self._build_classification_prompt(domain, samples)
                response = self._get_llm().invoke(prompt)
                response_text = get_llm_response_text(response)

                result = extract_json(response_text, expected_type=dict)
                if result is None:
                    raise ValueError("Could not parse JSON from LLM response")

                fields = self._classification_fields(result, samples)
                if existing is not None:
                    # force_update on an already-classified domain: update in place.
                    for column, value in fields.items():
                        setattr(existing, column, value)
                    classification = existing
                else:
                    classification = DomainClassification(domain=domain, **fields)
                    session.add(classification)

                session.commit()
                # NOTE(review): the returned instance outlives this session;
                # callers read its attributes after the session closes - confirm
                # the session factory keeps them loaded.
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    @staticmethod
    def _collect_domains(session: Session) -> set:
        """Return the unique lower-cased hosts of all stored resource URLs.

        A leading ``www.`` is removed, and so is any ``user:password@`` prefix.
        Credentials embedded in a URL are therefore never stored as a domain
        or sent to the LLM. A port, if present, is kept.
        """
        from urllib.parse import urlparse

        domains = set()
        for (url,) in session.query(ResearchResource.url).distinct().all():
            if not url:
                continue
            try:
                netloc = urlparse(url).netloc
                domain = netloc.rpartition("@")[2].lower()
                if domain.startswith("www."):
                    domain = domain[4:]
            except (ValueError, AttributeError, TypeError):
                continue
            if domain:
                domains.add(domain)
        return domains

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains in the database.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callable receiving a dict with
                ``current``, ``total``, ``domain`` and ``percentage`` keys
                before each domain is processed

        Returns:
            Dictionary with classification results (``total``, ``classified``,
            ``failed``, ``skipped``, per-domain ``domains`` list, and ``error``
            if the run aborted)
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                domains = self._collect_domains(session)

                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain one by one, in a deterministic order.
                for i, domain in enumerate(sorted(domains), 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    try:
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(self, domain: str) -> Optional[DomainClassification]:
        """Get existing classification for a domain.

        Args:
            domain: Domain to look up

        Returns:
            DomainClassification object or None if not found or on error
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List[DomainClassification]:
        """Get all domain classifications, ordered by category then domain.

        Returns:
            List of all DomainClassification objects (empty on error)
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Get summary of domain classifications by category.

        Returns:
            Dict mapping category -> {"count", "domains", "subcategories"},
            where "subcategories" counts each non-empty subcategory.
            Empty dict on error.
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary: Dict = {}
                for classification in classifications:
                    entry = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )
                    entry["count"] += 1
                    entry["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    subcat = classification.subcategory
                    if subcat:
                        counts = entry["subcategories"]
                        counts[subcat] = counts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}