Coverage for src / local_deep_research / domain_classifier / classifier.py: 93%
155 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""Domain classifier using LLM for categorization."""
3import json
4from typing import Dict, List, Optional
6from loguru import logger
7from sqlalchemy.orm import Session
9from ..config.llm_config import get_llm
10from ..database.models import ResearchResource
11from ..database.session_context import get_user_db_session
12from ..utilities.json_utils import extract_json, get_llm_response_text
13from .models import DomainClassification
# Predefined categories for domain classification.
#
# Maps each main category to the subcategories that belong to it. The
# classification prompt lists the entries in this order, one line per main
# category, and asks the LLM to answer with one main category and one of its
# subcategories.
DOMAIN_CATEGORIES: Dict[str, List[str]] = {
    "Academic & Research": [
        "University/Education",
        "Scientific Journal",
        "Research Institution",
        "Academic Database",
    ],
    "News & Media": [
        "General News",
        "Tech News",
        "Business News",
        "Entertainment News",
        "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia",
        "Technical Documentation",
        "API Documentation",
        "Tutorial/Guide",
        "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        "Q&A Platform",  # Question-and-answer focused sites (StackOverflow, Quora)
        "Blog Platform",  # Structured publishing platforms (Medium, WordPress.com)
        "Personal Blog",  # Individual author blogs
    ],
    "Business & Commerce": [
        "E-commerce",
        "Corporate Website",
        "B2B Platform",
        "Financial Service",
        "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development",
        "Cloud Service",
        "Open Source Project",
        "Tech Company",
        "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency",
        "Non-profit",
        "International Organization",
        "Think Tank",
        "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service",
        "Gaming",
        "Sports",
        "Arts & Culture",
        "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare",
        "Legal",
        "Real Estate",
        "Manufacturing",
        "Energy & Utilities",
    ],
    # Fallback bucket: the prompt tells the LLM to answer "Other" / "Unknown"
    # when it is uncertain.
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}
class DomainClassifier:
    """Classify domains using LLM with predefined categories.

    Classifications are read from and written to the database session
    obtained through ``get_user_db_session`` for the configured username. The
    LLM client is created lazily on first use and released by :meth:`close`.
    """

    # Fallback values used when the LLM response omits a field or supplies a
    # value that cannot be used for it.
    _DEFAULT_CATEGORY = "Other"
    _DEFAULT_SUBCATEGORY = "Unknown"
    _DEFAULT_CONFIDENCE = 0.5

    def __init__(self, username: str, settings_snapshot: Optional[dict] = None):
        """Initialize the domain classifier.

        Args:
            username: Username for database session
            settings_snapshot: Settings snapshot for LLM configuration
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        # Populated on demand by _get_llm().
        self.llm = None

    def _get_llm(self):
        """Get or initialize LLM instance.

        Returns:
            The cached LLM client, created from ``settings_snapshot`` on the
            first call.
        """
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    def close(self) -> None:
        """Close the LLM client if one was created.

        The cached client is dropped afterwards, so a later classification
        builds a fresh client instead of reusing one that was already closed.
        """
        from ..utilities.resource_utils import safe_close

        safe_close(self.llm, "classifier LLM")
        self.llm = None

    @staticmethod
    def _coerce_confidence(value, default: float = 0.5) -> float:
        """Convert an LLM-supplied confidence into a float within [0, 1].

        Args:
            value: Raw ``confidence`` value taken from the parsed response
            default: Value returned when ``value`` is missing, not numeric,
                or NaN

        Returns:
            ``value`` as a float clamped to the range 0..1, or ``default``
        """
        import math

        try:
            confidence = float(value)
        except (TypeError, ValueError):
            # e.g. None (JSON null) or a word such as "high"
            return default
        if math.isnan(confidence):
            return default
        return min(max(confidence, 0.0), 1.0)

    @classmethod
    def _classification_fields(cls, result: Dict, samples: List[Dict]) -> Dict:
        """Map a parsed LLM response onto DomainClassification field values.

        Args:
            result: JSON object parsed from the LLM response
            samples: Sample resources that were shown to the LLM

        Returns:
            Keyword arguments / attribute values shared by the create and
            update paths of ``classify_domain``
        """
        return {
            # `or` also covers explicit nulls and empty strings, which a
            # plain .get() default would let through.
            "category": result.get("category") or cls._DEFAULT_CATEGORY,
            "subcategory": result.get("subcategory")
            or cls._DEFAULT_SUBCATEGORY,
            "confidence": cls._coerce_confidence(
                result.get("confidence"), cls._DEFAULT_CONFIDENCE
            ),
            "reasoning": result.get("reasoning") or "",
            "sample_titles": json.dumps([s["title"] for s in samples]),
            "sample_count": len(samples),
        }

    @staticmethod
    def _extract_domain(url: Optional[str]) -> Optional[str]:
        """Return the lower-cased network location of a URL.

        A leading ``www.`` is removed. Anything else in the network location
        (for example a port) is kept.

        Args:
            url: URL to inspect; may be empty or None

        Returns:
            The domain, or None when the URL is empty, malformed, or has no
            network location
        """
        from urllib.parse import urlparse

        if not url:
            return None
        try:
            domain = urlparse(url).netloc.lower()
        except (ValueError, AttributeError):
            logger.debug("Skipping malformed URL")
            return None
        if domain.startswith("www."):
            domain = domain[4:]
        return domain or None

    def _collect_domains(self, session: Session) -> List[str]:
        """Return the sorted unique domains of all stored resource URLs.

        Args:
            session: Database session

        Returns:
            Alphabetically sorted list of domains
        """
        rows = session.query(ResearchResource.url).distinct().all()
        domains = set()
        for (url,) in rows:
            domain = self._extract_domain(url)
            if domain:
                domains.add(domain)
        return sorted(domains)

    def _get_domain_samples(
        self, domain: str, session: Session, limit: int = 5
    ) -> List[Dict]:
        """Get sample resources from a domain.

        Args:
            domain: Domain to get samples for
            session: Database session
            limit: Maximum number of samples

        Returns:
            List of resource samples, each a dict with ``title`` ("Untitled"
            when missing), ``url`` and ``preview`` (first 200 characters of
            the content preview, or None)
        """
        resources = (
            session.query(ResearchResource)
            # Substring match on the URL. autoescape keeps "%" and "_" in
            # the domain literal instead of letting them act as wildcards.
            .filter(ResearchResource.url.contains(domain, autoescape=True))
            .limit(limit)
            .all()
        )

        return [
            {
                "title": resource.title or "Untitled",
                "url": resource.url,
                "preview": resource.content_preview[:200]
                if resource.content_preview
                else None,
            }
            for resource in resources
        ]

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        This method uses actual content samples (titles, previews) from the domain
        rather than relying solely on domain name patterns, providing more
        accurate classification based on actual site content.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs, content previews)

        Returns:
            Formatted prompt string
        """
        # Format categories for prompt: one "Main: sub, sub, ..." line each
        categories_text = [
            f"{main_cat}: {', '.join(subcats)}"
            for main_cat, subcats in DOMAIN_CATEGORIES.items()
        ]

        # Format samples: at most five, previews cut to 100 characters
        samples_text = []
        for i, sample in enumerate(samples[:5], 1):
            samples_text.append(f"{i}. Title: {sample['title']}")
            if sample.get("preview"):
                # NOTE(review): leading whitespace of this literal is
                # reproduced as seen; confirm against the checked-in file.
                samples_text.append(f" Preview: {sample['preview'][:100]}...")

        return f"""Classify the following domain into one of the predefined categories.

Domain: {domain}

Sample content from this domain:
{chr(10).join(samples_text) if samples_text else "No samples available"}

Available Categories:
{chr(10).join(categories_text)}

Respond with a JSON object containing:
- "category": The main category (e.g., "News & Media")
- "subcategory": The specific subcategory (e.g., "Tech News")
- "confidence": A confidence score between 0 and 1
- "reasoning": A brief explanation (max 100 words) of why this classification was chosen

Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.

JSON Response:"""

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional[DomainClassification]:
        """Classify a single domain using LLM.

        An existing classification is returned unchanged unless
        ``force_update`` is set. Otherwise the LLM is asked to classify the
        domain and the result is stored (updating the existing row if there
        is one). Missing or unusable response fields fall back to
        "Other" / "Unknown" / 0.5 / "" instead of failing the classification.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if already exists

        Returns:
            DomainClassification object or None if failed (any exception is
            logged and swallowed)
        """
        try:
            with get_user_db_session(self.username) as session:
                # Check if already classified
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )

                if existing and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                # Get sample resources
                samples = self._get_domain_samples(domain, session)

                # Build prompt and get classification
                prompt = self._build_classification_prompt(domain, samples)
                llm = self._get_llm()

                response = llm.invoke(prompt)
                response_text = get_llm_response_text(response)

                result = extract_json(response_text, expected_type=dict)
                if result is None:
                    raise ValueError("Could not parse JSON from LLM response")  # noqa: TRY301 — inside db session; except logs and returns None

                # Create or update classification from one shared mapping so
                # both paths apply the same defaults.
                fields = self._classification_fields(result, samples)
                if existing:
                    for name, value in fields.items():
                        setattr(existing, name, value)
                    classification = existing
                else:
                    classification = DomainClassification(
                        domain=domain, **fields
                    )
                    session.add(classification)

                session.commit()
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains in the database.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callback function for progress updates.
                It is called once per domain, before that domain is checked
                or classified, with a dict holding ``current``, ``total``,
                ``domain`` and ``percentage``.

        Returns:
            Dictionary with classification results: the counters ``total``,
            ``classified``, ``failed`` and ``skipped``, a ``domains`` list
            with one status entry per processed domain, and an ``error`` key
            when the run as a whole was aborted by an exception
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                # Get all unique domains
                domains = self._collect_domains(session)

                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain ONE BY ONE
                for i, domain in enumerate(domains, 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    # A failure on one domain is recorded and must not stop
                    # the remaining domains from being processed.
                    try:
                        # Check if already classified
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        # Classify this single domain
                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(self, domain: str) -> Optional[DomainClassification]:
        """Get existing classification for a domain.

        Args:
            domain: Domain to look up

        Returns:
            DomainClassification object or None if not found (also None when
            the lookup raises; the exception is logged)
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List[DomainClassification]:
        """Get all domain classifications.

        Returns:
            List of all DomainClassification objects ordered by category and
            then domain; an empty list when the query raises (the exception
            is logged)
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Get summary of domain classifications by category.

        Returns:
            Dictionary keyed by category. Each value holds ``count`` (number
            of domains), ``domains`` (list of dicts with ``domain``,
            ``subcategory`` and ``confidence``) and ``subcategories``
            (subcategory name -> number of domains). An empty dict is
            returned when the query raises (the exception is logged).
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary: Dict = {}
                for classification in classifications:
                    entry = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )

                    entry["count"] += 1
                    entry["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    # Rows without a subcategory are counted for the
                    # category only.
                    subcat = classification.subcategory
                    if subcat:
                        subcat_counts = entry["subcategories"]
                        subcat_counts[subcat] = subcat_counts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}