Coverage for src / local_deep_research / domain_classifier / classifier.py: 44%
158 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""Domain classifier using LLM for categorization."""
3import json
4from typing import List, Dict, Optional
5from loguru import logger
6from sqlalchemy.orm import Session
7from ..database.session_context import get_user_db_session
8from ..database.models import ResearchResource
9from .models import DomainClassification
10from ..config.llm_config import get_llm
# Fixed taxonomy offered to the LLM: each main category maps to the list of
# subcategories that may be chosen under it. Insertion order is significant
# because the classification prompt lists the categories in this order.
DOMAIN_CATEGORIES = {
    "Academic & Research": [
        "University/Education",
        "Scientific Journal",
        "Research Institution",
        "Academic Database",
    ],
    "News & Media": [
        "General News",
        "Tech News",
        "Business News",
        "Entertainment News",
        "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia",
        "Technical Documentation",
        "API Documentation",
        "Tutorial/Guide",
        "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        # Question-and-answer focused sites (StackOverflow, Quora)
        "Q&A Platform",
        # Structured publishing platforms (Medium, WordPress.com)
        "Blog Platform",
        # Individual author blogs
        "Personal Blog",
    ],
    "Business & Commerce": [
        "E-commerce",
        "Corporate Website",
        "B2B Platform",
        "Financial Service",
        "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development",
        "Cloud Service",
        "Open Source Project",
        "Tech Company",
        "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency",
        "Non-profit",
        "International Organization",
        "Think Tank",
        "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service",
        "Gaming",
        "Sports",
        "Arts & Culture",
        "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare",
        "Legal",
        "Real Estate",
        "Manufacturing",
        "Energy & Utilities",
    ],
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}
class DomainClassifier:
    """Classify domains using LLM with predefined categories.

    Classifications are stored as ``DomainClassification`` rows in the
    database session obtained for ``username``. The LLM is created lazily
    on first use via ``get_llm`` and then reused for later calls.
    """

    def __init__(self, username: str, settings_snapshot: Optional[dict] = None):
        """Initialize the domain classifier.

        Args:
            username: Username for database session
            settings_snapshot: Settings snapshot for LLM configuration
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        # Created on demand by _get_llm().
        self.llm = None

    def _get_llm(self):
        """Get or initialize LLM instance."""
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    def _get_domain_samples(
        self, domain: str, session: "Session", limit: int = 5
    ) -> List[Dict]:
        """Get sample resources from a domain.

        Args:
            domain: Domain to get samples for
            session: Database session
            limit: Maximum number of samples

        Returns:
            List of dicts with ``title``, ``url`` and ``preview`` keys;
            ``preview`` holds at most 200 characters, or None when the
            resource has no content preview.
        """
        resources = (
            session.query(ResearchResource)
            # autoescape keeps "%" and "_" inside the domain literal instead
            # of letting them act as LIKE wildcards.
            .filter(ResearchResource.url.contains(domain, autoescape=True))
            .limit(limit)
            .all()
        )

        return [
            {
                "title": resource.title or "Untitled",
                "url": resource.url,
                "preview": resource.content_preview[:200]
                if resource.content_preview
                else None,
            }
            for resource in resources
        ]

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        The prompt includes titles and content previews of sampled resources
        in addition to the domain name, so the model can classify on actual
        site content rather than on the name alone.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs, content previews)

        Returns:
            Formatted prompt string
        """
        # One "Main category: sub1, sub2, ..." line per category.
        categories_block = "\n".join(
            f"{main_cat}: {', '.join(subcats)}"
            for main_cat, subcats in DOMAIN_CATEGORIES.items()
        )

        # At most five samples; previews are cut to 100 characters.
        samples_text = []
        for i, sample in enumerate(samples[:5], 1):
            samples_text.append(f"{i}. Title: {sample['title']}")
            if sample.get("preview"):
                samples_text.append(f" Preview: {sample['preview'][:100]}...")
        samples_block = (
            "\n".join(samples_text) if samples_text else "No samples available"
        )

        prompt = f"""Classify the following domain into one of the predefined categories.

Domain: {domain}

Sample content from this domain:
{samples_block}

Available Categories:
{categories_block}

Respond with a JSON object containing:
- "category": The main category (e.g., "News & Media")
- "subcategory": The specific subcategory (e.g., "Tech News")
- "confidence": A confidence score between 0 and 1
- "reasoning": A brief explanation (max 100 words) of why this classification was chosen

Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.

JSON Response:"""

        return prompt

    @staticmethod
    def _parse_llm_response(response_text: str) -> Dict:
        """Turn raw LLM output into normalized classification fields.

        Args:
            response_text: Text returned by the LLM

        Returns:
            Dict with ``category``, ``subcategory``, ``confidence`` and
            ``reasoning``. Missing or null fields fall back to "Other",
            "Unknown", 0.5 and "" respectively; confidence is clamped to
            the range [0, 1].

        Raises:
            ValueError: If no JSON object can be extracted from the text
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        import re

        try:
            result = json.loads(response_text)
        except json.JSONDecodeError:
            # The model may surround the JSON with prose; take the outermost
            # brace-delimited span and parse that instead.
            json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
            if json_match is None:
                raise ValueError("Could not parse JSON from LLM response")
            result = json.loads(json_match.group())

        if not isinstance(result, dict):
            raise ValueError("LLM response JSON is not an object")

        # A non-numeric or null confidence should not discard an otherwise
        # usable classification, so fall back to the neutral default.
        try:
            confidence = float(result.get("confidence", 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        if confidence != confidence:  # NaN is the only value unequal to itself
            confidence = 0.5
        confidence = min(1.0, max(0.0, confidence))

        return {
            "category": result.get("category") or "Other",
            "subcategory": result.get("subcategory") or "Unknown",
            "confidence": confidence,
            "reasoning": result.get("reasoning") or "",
        }

    @staticmethod
    def _extract_domain(url: str) -> Optional[str]:
        """Return the lowercased network location of ``url`` without "www.".

        Args:
            url: URL to extract the domain from

        Returns:
            The normalized domain, or None when the URL is empty, cannot be
            parsed, or has no network location.
        """
        from urllib.parse import urlparse

        if not url:
            return None
        try:
            domain = urlparse(url).netloc.lower()
        except ValueError:
            # urlparse rejects some malformed URLs (e.g. a broken IPv6 host).
            return None
        if domain.startswith("www."):
            domain = domain[4:]
        return domain or None

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional["DomainClassification"]:
        """Classify a single domain using LLM.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if already exists

        Returns:
            DomainClassification object or None if failed. Any exception is
            logged and converted into a None return.
        """
        try:
            with get_user_db_session(self.username) as session:
                # Check if already classified
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )

                if existing and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                # Gather evidence, then ask the LLM.
                samples = self._get_domain_samples(domain, session)
                prompt = self._build_classification_prompt(domain, samples)
                response = self._get_llm().invoke(prompt)

                # Responses may be objects exposing ``content`` or plain values.
                if hasattr(response, "content"):
                    response_text = response.content
                else:
                    response_text = str(response)
                if not isinstance(response_text, str):
                    response_text = str(response_text)

                fields = self._parse_llm_response(response_text)
                fields["sample_titles"] = json.dumps(
                    [s["title"] for s in samples]
                )
                fields["sample_count"] = len(samples)

                # Create or update classification
                if existing:
                    for attr, value in fields.items():
                        setattr(existing, attr, value)
                    classification = existing
                else:
                    classification = DomainClassification(
                        domain=domain, **fields
                    )
                    session.add(classification)

                session.commit()
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains in the database.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callback function for progress updates;
                it is called once per domain with a dict holding ``current``,
                ``total``, ``domain`` and ``percentage``.

        Returns:
            Dictionary with ``total``, ``classified``, ``failed`` and
            ``skipped`` counts plus a per-domain ``domains`` list. An
            ``error`` key is added when the run as a whole fails.
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                # Get all unique domains
                resources = session.query(ResearchResource.url).distinct().all()
                domains = set()
                for (url,) in resources:
                    domain = self._extract_domain(url)
                    if domain:
                        domains.add(domain)

                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain ONE BY ONE
                for i, domain in enumerate(sorted(domains), 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    try:
                        # Check if already classified
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        # Classify this single domain
                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        # One bad domain must not abort the whole run.
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(
        self, domain: str
    ) -> Optional["DomainClassification"]:
        """Get existing classification for a domain.

        Args:
            domain: Domain to look up

        Returns:
            DomainClassification object or None if not found (or if the
            lookup raised, in which case the error is logged)
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List["DomainClassification"]:
        """Get all domain classifications.

        Returns:
            List of all DomainClassification objects ordered by category and
            then domain; an empty list if the query raised
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Get summary of domain classifications by category.

        Returns:
            Dictionary keyed by category; each value holds ``count``, a
            ``domains`` list (domain, subcategory, confidence) and a
            ``subcategories`` mapping of subcategory name to count. An
            empty dict is returned if the query raised.
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary = {}
                for classification in classifications:
                    entry = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )
                    entry["count"] += 1
                    entry["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    subcat = classification.subcategory
                    if subcat:
                        subcounts = entry["subcategories"]
                        subcounts[subcat] = subcounts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}