Coverage for src / local_deep_research / domain_classifier / classifier.py: 93%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1"""Domain classifier using LLM for categorization.""" 

2 

3import json 

4from typing import Dict, List, Optional 

5 

6from loguru import logger 

7from sqlalchemy.orm import Session 

8 

9from ..config.llm_config import get_llm 

10from ..database.models import ResearchResource 

11from ..database.session_context import get_user_db_session 

12from ..utilities.json_utils import extract_json, get_llm_response_text 

13from .models import DomainClassification 

14 

15 

16# Predefined categories for domain classification 

# Taxonomy offered to the LLM: each main category maps to the subcategories
# it may choose from. Dict insertion order is the order in which the
# categories are listed in the classification prompt.
DOMAIN_CATEGORIES = {
    "Academic & Research": [
        "University/Education", "Scientific Journal",
        "Research Institution", "Academic Database",
    ],
    "News & Media": [
        "General News", "Tech News", "Business News",
        "Entertainment News", "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia", "Technical Documentation", "API Documentation",
        "Tutorial/Guide", "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        "Q&A Platform",  # Question-and-answer focused sites (StackOverflow, Quora)
        "Blog Platform",  # Structured publishing platforms (Medium, WordPress.com)
        "Personal Blog",  # Individual author blogs
    ],
    "Business & Commerce": [
        "E-commerce", "Corporate Website", "B2B Platform",
        "Financial Service", "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development", "Cloud Service", "Open Source Project",
        "Tech Company", "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency", "Non-profit", "International Organization",
        "Think Tank", "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service", "Gaming", "Sports",
        "Arts & Culture", "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare", "Legal", "Real Estate",
        "Manufacturing", "Energy & Utilities",
    ],
    # Fallback bucket; the prompt asks the LLM to answer "Other"/"Unknown"
    # when it is uncertain.
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}

82 

83 

class DomainClassifier:
    """Classify domains using an LLM and the predefined DOMAIN_CATEGORIES.

    Results are persisted as ``DomainClassification`` rows in the database
    session returned by ``get_user_db_session(username)``.
    ``classify_all_domains`` derives domains from stored resource URLs as
    lowercase hostnames with any leading ``www.`` removed (see
    ``_normalize_host``).
    """

    # Candidate rows fetched per requested sample. The SQL LIKE filter in
    # _get_domain_samples is only a coarse pre-filter; exact host matching is
    # done in Python, so some fetched rows may be discarded.
    _SAMPLE_SCAN_FACTOR = 10

    def __init__(self, username: str, settings_snapshot: Optional[dict] = None):
        """Initialize the domain classifier.

        Args:
            username: Username used to open the per-user database session
            settings_snapshot: Settings snapshot passed to ``get_llm`` when
                the LLM client is first created
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        self.llm = None  # created lazily by _get_llm()

    def _get_llm(self):
        """Return the LLM client, creating it on first use."""
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    def close(self) -> None:
        """Close the LLM client if one was created.

        The cached client is dropped afterwards, so a later ``_get_llm()``
        call builds a fresh client instead of returning a closed one.
        """
        from ..utilities.resource_utils import safe_close

        safe_close(self.llm, "classifier LLM")
        self.llm = None

    @staticmethod
    def _normalize_host(url: Optional[str]) -> Optional[str]:
        """Return the lowercase hostname of *url* without a leading ``www.``.

        Ports and ``user:password@`` credentials are dropped, so they never
        end up in stored domain names or in LLM prompts.

        Args:
            url: URL to parse. A bare host can be passed as ``"//host"``.

        Returns:
            The normalized host, or None if *url* is empty, not a string,
            unparseable, or has no network location.
        """
        if not isinstance(url, str) or not url:
            return None

        from urllib.parse import urlparse

        try:
            host = urlparse(url).hostname  # already lowercased by urllib
        except ValueError:  # e.g. unbalanced IPv6 brackets
            return None
        if host and host.startswith("www."):
            host = host[4:]
        return host or None

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape SQL LIKE wildcards in *text*, using backslash as escape char."""
        return (
            text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )

    @staticmethod
    def _parse_classification(result: Dict) -> Dict:
        """Normalize the JSON object returned by the LLM into column values.

        Missing, null or blank text fields fall back to "Other" / "Unknown" /
        "". A confidence that is missing, non-numeric or NaN becomes 0.5.
        Any other confidence is clamped into [0, 1].

        Args:
            result: Parsed JSON object from the LLM response

        Returns:
            Dict with keys ``category``, ``subcategory``, ``confidence`` and
            ``reasoning``.
        """
        import math

        def _text(key: str, default: str) -> str:
            value = result.get(key)
            if value is None:
                return default
            return str(value).strip() or default

        try:
            confidence = float(result.get("confidence", 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        if math.isnan(confidence):
            confidence = 0.5
        confidence = min(max(confidence, 0.0), 1.0)

        return {
            "category": _text("category", "Other"),
            "subcategory": _text("subcategory", "Unknown"),
            "confidence": confidence,
            "reasoning": _text("reasoning", ""),
        }

    def _get_domain_samples(
        self, domain: str, session: Session, limit: int = 5
    ) -> List[Dict]:
        """Get sample resources whose URL host is exactly *domain*.

        ``www.`` and non-``www.`` variants of the host are treated as the
        same domain. A plain substring match is NOT used, because it would
        pull in unrelated sites (e.g. ``x.com`` inside ``netflix.com``).

        Args:
            domain: Domain to get samples for (a bare host or a URL)
            session: Database session
            limit: Maximum number of samples

        Returns:
            Up to *limit* dicts with ``title``, ``url`` and ``preview`` (first
            200 characters of the content preview, or None).
        """
        if not domain or limit <= 0:
            return []
        host = self._normalize_host(domain if "://" in domain else f"//{domain}")
        if not host:
            return []

        from sqlalchemy import or_

        # Coarse SQL pre-filter: the host must directly follow "://",
        # "://www." or a userinfo "@". Wildcards in the host are escaped.
        escaped = self._escape_like(host)
        url_column = ResearchResource.url
        candidates = (
            session.query(ResearchResource)
            .filter(
                or_(
                    url_column.like(f"%://{escaped}%", escape="\\"),
                    url_column.like(f"%://www.{escaped}%", escape="\\"),
                    url_column.like(f"%@{escaped}%", escape="\\"),
                )
            )
            .limit(limit * self._SAMPLE_SCAN_FACTOR)
            .all()
        )

        samples = []
        for resource in candidates:
            # Exact host check drops prefixes like "x.com.evil.org".
            if self._normalize_host(resource.url) != host:
                continue
            preview = resource.content_preview
            samples.append(
                {
                    "title": resource.title or "Untitled",
                    "url": resource.url,
                    "preview": preview[:200] if preview else None,
                }
            )
            if len(samples) >= limit:
                break
        return samples

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        The prompt includes actual content samples (titles, previews) from
        the domain, not just the domain name, so the LLM can classify based
        on what the site actually publishes.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs, content
                previews); at most the first 5 are used

        Returns:
            Formatted prompt string
        """
        categories_text = "\n".join(
            f"{main_cat}: {', '.join(subcats)}"
            for main_cat, subcats in DOMAIN_CATEGORIES.items()
        )

        sample_lines = []
        for i, sample in enumerate(samples[:5], 1):
            sample_lines.append(f"{i}. Title: {sample['title']}")
            preview = sample.get("preview")
            if preview:
                snippet = preview[:100]
                if len(preview) > 100:
                    snippet += "..."  # mark truncation only when it happened
                sample_lines.append(f"   Preview: {snippet}")
        samples_text = (
            "\n".join(sample_lines) if sample_lines else "No samples available"
        )

        return f"""Classify the following domain into one of the predefined categories.

Domain: {domain}

Sample content from this domain:
{samples_text}

Available Categories:
{categories_text}

Respond with a JSON object containing:
- "category": The main category (e.g., "News & Media")
- "subcategory": The specific subcategory (e.g., "Tech News")
- "confidence": A confidence score between 0 and 1
- "reasoning": A brief explanation (max 100 words) of why this classification was chosen

Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.

JSON Response:"""

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional[DomainClassification]:
        """Classify a single domain using the LLM and persist the result.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if a classification exists

        Returns:
            The stored DomainClassification, or None if anything failed
            (errors are logged, not raised).
        """
        try:
            with get_user_db_session(self.username) as session:
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
                if existing and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                samples = self._get_domain_samples(domain, session)
                prompt = self._build_classification_prompt(domain, samples)
                response = self._get_llm().invoke(prompt)
                result = extract_json(
                    get_llm_response_text(response), expected_type=dict
                )
                if result is None:
                    raise ValueError("Could not parse JSON from LLM response")  # noqa: TRY301 — inside db session; except logs and returns None

                fields = self._parse_classification(result)
                fields["sample_titles"] = json.dumps(
                    [s["title"] for s in samples]
                )
                fields["sample_count"] = len(samples)

                if existing:
                    for column, value in fields.items():
                        setattr(existing, column, value)
                    classification = existing
                else:
                    classification = DomainClassification(domain=domain, **fields)
                    session.add(classification)

                session.commit()
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    def _collect_domains(self, session: Session) -> set:
        """Return the normalized hosts of all distinct stored resource URLs."""
        domains = set()
        for (url,) in session.query(ResearchResource.url).distinct().all():
            host = self._normalize_host(url)
            if host:
                domains.add(host)
            elif url:
                logger.debug("Skipping malformed URL")
        return domains

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains found in stored resource URLs.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callable that receives a dict with
                ``current``, ``total``, ``domain`` and ``percentage`` before
                each domain is processed

        Returns:
            Dict with ``total``/``classified``/``failed``/``skipped`` counts
            and a per-domain ``domains`` list. On a fatal error, an ``error``
            key is added as well.
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                domains = self._collect_domains(session)
                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain ONE BY ONE
                for i, domain in enumerate(sorted(domains), 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    try:
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(self, domain: str) -> Optional[DomainClassification]:
        """Get the existing classification for a domain.

        Args:
            domain: Domain to look up (matched exactly against the stored key)

        Returns:
            DomainClassification object, or None if not found or on error
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List[DomainClassification]:
        """Get all domain classifications ordered by category, then domain.

        Returns:
            List of DomainClassification objects (empty on error)
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Summarize stored classifications grouped by main category.

        Returns:
            Dict mapping category -> {"count", "domains", "subcategories"}.
            ``subcategories`` counts only non-empty subcategory values.
            An empty dict is returned on error.
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary: Dict[str, Dict] = {}
                for classification in classifications:
                    bucket = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )
                    bucket["count"] += 1
                    bucket["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    subcat = classification.subcategory
                    if subcat:
                        counts = bucket["subcategories"]
                        counts[subcat] = counts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}