Coverage for src / local_deep_research / domain_classifier / classifier.py: 90%

152 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1"""Domain classifier using LLM for categorization.""" 

2 

3import json 

4from typing import List, Dict, Optional 

5from loguru import logger 

6from sqlalchemy.orm import Session 

7from ..database.session_context import get_user_db_session 

8from ..database.models import ResearchResource 

9from .models import DomainClassification 

10from ..config.llm_config import get_llm 

11from ..utilities.json_utils import extract_json, get_llm_response_text 

12 

13 

# Fixed taxonomy offered to the LLM, mapping each top-level category to the
# subcategories it may choose from. Insertion order is the order in which the
# categories are listed in the classification prompt. "Other"/"Unknown" is the
# pair the prompt tells the model to use when it is uncertain.
DOMAIN_CATEGORIES = {
    "Academic & Research": [
        "University/Education", "Scientific Journal",
        "Research Institution", "Academic Database",
    ],
    "News & Media": [
        "General News", "Tech News", "Business News",
        "Entertainment News", "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia", "Technical Documentation", "API Documentation",
        "Tutorial/Guide", "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        # Question-and-answer focused sites (StackOverflow, Quora)
        "Q&A Platform",
        # Structured publishing platforms (Medium, WordPress.com)
        "Blog Platform",
        # Individual author blogs
        "Personal Blog",
    ],
    "Business & Commerce": [
        "E-commerce", "Corporate Website", "B2B Platform",
        "Financial Service", "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development", "Cloud Service", "Open Source Project",
        "Tech Company", "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency", "Non-profit", "International Organization",
        "Think Tank", "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service", "Gaming", "Sports",
        "Arts & Culture", "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare", "Legal", "Real Estate",
        "Manufacturing", "Energy & Utilities",
    ],
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}

80 

81 

class DomainClassifier:
    """Classify domains using an LLM and the predefined DOMAIN_CATEGORIES.

    Results are stored as ``DomainClassification`` rows in the database
    session returned by ``get_user_db_session(username)``.
    """

    def __init__(self, username: str, settings_snapshot: Optional[dict] = None):
        """Initialize the domain classifier.

        Args:
            username: Username for database session
            settings_snapshot: Settings snapshot for LLM configuration
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        # Created lazily by _get_llm() so constructing a classifier is cheap.
        self.llm = None

    def _get_llm(self):
        """Return the cached LLM instance, creating it on first use."""
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    @staticmethod
    def _normalize_host(host: Optional[str]) -> str:
        """Lower-case *host* and drop a single leading ``www.`` prefix.

        This produces the domain key used by ``classify_all_domains``.
        Returns "" for ``None``/empty input.
        """
        host = (host or "").lower()
        if host.startswith("www."):
            host = host[4:]
        return host

    @classmethod
    def _domain_from_url(cls, url: Optional[str]) -> str:
        """Extract the normalized domain key (URL netloc) from *url*.

        Returns "" when *url* is empty/not a string or cannot be parsed, so
        callers can simply skip falsy results.
        """
        from urllib.parse import urlparse

        if not url or not isinstance(url, str):
            return ""
        try:
            return cls._normalize_host(urlparse(url).netloc)
        except (ValueError, AttributeError):
            # e.g. "http://[::1" raises ValueError (invalid IPv6 literal)
            return ""

    def _get_domain_samples(
        self, domain: str, session: Session, limit: int = 5
    ) -> List[Dict]:
        """Get sample resources whose URL host is exactly *domain*.

        Hosts are compared after ``_normalize_host`` normalization, so
        ``www.example.com`` and ``example.com`` are the same domain. Unlike a
        plain substring match, ``x.com`` does NOT match ``dropbox.com`` or a
        URL that only mentions ``x.com`` in its query string.

        Args:
            domain: Domain to get samples for
            session: Database session
            limit: Maximum number of samples

        Returns:
            List of dicts with ``title`` ("Untitled" if missing), ``url`` and
            ``preview`` (first 200 chars of ``content_preview`` or None)
        """
        from sqlalchemy import or_

        target = self._normalize_host(domain)
        if not target or limit <= 0:
            return []

        # Anchor the host right after the scheme separator and require it to
        # be followed by a path/query/fragment delimiter or the end of the
        # URL, so the database (and LIMIT) works on relevant rows only.
        patterns = []
        for host in (target, f"www.{target}"):
            patterns.append(f"%://{host}")
            patterns.extend(f"%://{host}{sep}%" for sep in ("/", "?", "#"))

        candidates = (
            session.query(ResearchResource)
            .filter(or_(*(ResearchResource.url.like(p) for p in patterns)))
            # Headroom for the rare rows rejected by the exact check below
            # (e.g. the pattern matched a URL embedded in a query string).
            .limit(limit * 2)
            .all()
        )

        samples = []
        for resource in candidates:
            if self._domain_from_url(resource.url) != target:
                continue
            preview = resource.content_preview
            samples.append(
                {
                    "title": resource.title or "Untitled",
                    "url": resource.url,
                    "preview": preview[:200] if preview else None,
                }
            )
            if len(samples) >= limit:
                break

        return samples

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        This method uses actual content samples (titles, previews) from the domain
        rather than relying solely on domain name patterns, providing more
        accurate classification based on actual site content.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs, content previews)

        Returns:
            Formatted prompt string
        """
        # One "Main: sub1, sub2, ..." line per category, in taxonomy order.
        categories_text = []
        for main_cat, subcats in DOMAIN_CATEGORIES.items():
            subcats_text = ", ".join(subcats)
            categories_text.append(f"{main_cat}: {subcats_text}")

        # At most five numbered samples; previews are shortened to 100 chars.
        samples_text = []
        for i, sample in enumerate(samples[:5], 1):
            samples_text.append(f"{i}. Title: {sample['title']}")
            if sample.get("preview"):
                samples_text.append(f" Preview: {sample['preview'][:100]}...")

        # chr(10) is "\n": backslashes are not allowed inside f-string
        # expressions before Python 3.12.
        prompt = f"""Classify the following domain into one of the predefined categories.

Domain: {domain}

Sample content from this domain:
{chr(10).join(samples_text) if samples_text else "No samples available"}

Available Categories:
{chr(10).join(categories_text)}

Respond with a JSON object containing:
- "category": The main category (e.g., "News & Media")
- "subcategory": The specific subcategory (e.g., "Tech News")
- "confidence": A confidence score between 0 and 1
- "reasoning": A brief explanation (max 100 words) of why this classification was chosen

Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.

JSON Response:"""

        return prompt

    @staticmethod
    def _text_or_default(value, default: str) -> str:
        """Return *value* as a stripped string, or *default* if None/blank."""
        if value is None:
            return default
        text = str(value).strip()
        return text or default

    @staticmethod
    def _parse_confidence(value, default: float = 0.5) -> float:
        """Convert an LLM-supplied confidence to a float in [0, 1].

        Non-numeric, missing or NaN values yield *default* instead of raising,
        so a malformed confidence does not discard an otherwise valid answer.
        Out-of-range numbers are clamped.
        """
        import math

        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if math.isnan(confidence):
            return default
        return min(max(confidence, 0.0), 1.0)

    @classmethod
    def _classification_fields(cls, result: Dict, samples: List[Dict]) -> Dict:
        """Map the parsed LLM JSON answer onto DomainClassification columns.

        Missing/null/blank category and subcategory fall back to "Other" and
        "Unknown" (the fallback pair named in the prompt).
        """
        return {
            "category": cls._text_or_default(result.get("category"), "Other"),
            "subcategory": cls._text_or_default(
                result.get("subcategory"), "Unknown"
            ),
            "confidence": cls._parse_confidence(result.get("confidence")),
            "reasoning": cls._text_or_default(result.get("reasoning"), ""),
            "sample_titles": json.dumps([s["title"] for s in samples]),
            "sample_count": len(samples),
        }

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional[DomainClassification]:
        """Classify a single domain using LLM.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if already exists

        Returns:
            DomainClassification object (the stored row, existing or new) or
            None if anything failed; errors are logged, not raised.
        """
        try:
            with get_user_db_session(self.username) as session:
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )

                if existing and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                samples = self._get_domain_samples(domain, session)

                prompt = self._build_classification_prompt(domain, samples)
                llm = self._get_llm()

                response = llm.invoke(prompt)
                response_text = get_llm_response_text(response)

                result = extract_json(response_text, expected_type=dict)
                if result is None:
                    raise ValueError("Could not parse JSON from LLM response")

                fields = self._classification_fields(result, samples)

                # Create or update classification
                if existing:
                    for column, value in fields.items():
                        setattr(existing, column, value)
                    classification = existing
                else:
                    classification = DomainClassification(
                        domain=domain, **fields
                    )
                    session.add(classification)

                session.commit()
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains in the database.

        Domains are extracted from ``ResearchResource.url`` with
        ``_domain_from_url`` and processed one at a time in sorted order.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callable invoked before each domain
                with a dict of ``current``, ``total``, ``domain`` and
                ``percentage``

        Returns:
            Dictionary with ``total``/``classified``/``failed``/``skipped``
            counters and a per-domain ``domains`` list; an ``error`` key is
            added if the whole run aborts.
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                # Get all unique domains
                domains = set()
                for (url,) in session.query(ResearchResource.url).distinct().all():
                    domain = self._domain_from_url(url)
                    if domain:
                        domains.add(domain)

                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain ONE BY ONE
                for i, domain in enumerate(sorted(domains), 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    try:
                        # Check if already classified
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        # classify_domain opens its own session and commits.
                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(self, domain: str) -> Optional[DomainClassification]:
        """Get existing classification for a domain.

        Args:
            domain: Domain to look up (matched exactly against the stored key)

        Returns:
            DomainClassification object or None if not found or on error
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List[DomainClassification]:
        """Get all domain classifications.

        Returns:
            List of all DomainClassification objects ordered by category, then
            domain; an empty list on error
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Get summary of domain classifications by category.

        Returns:
            Dictionary keyed by category, each value holding ``count``, a
            ``domains`` list (domain, subcategory, confidence) and a
            ``subcategories`` count map; an empty dict on error
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary = {}
                for classification in classifications:
                    entry = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )
                    entry["count"] += 1
                    entry["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    subcat = classification.subcategory
                    if subcat:
                        counts = entry["subcategories"]
                        counts[subcat] = counts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}