Coverage for src / local_deep_research / domain_classifier / classifier.py: 44%

158 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1"""Domain classifier using LLM for categorization.""" 

2 

3import json 

4from typing import List, Dict, Optional 

5from loguru import logger 

6from sqlalchemy.orm import Session 

7from ..database.session_context import get_user_db_session 

8from ..database.models import ResearchResource 

9from .models import DomainClassification 

10from ..config.llm_config import get_llm 

11 

12 

# Two-level taxonomy offered to the LLM when classifying a domain.
# Keys are the main categories; each value lists that category's subcategories.
# Insertion order is preserved and is the order in which the categories are
# rendered into the classification prompt.
DOMAIN_CATEGORIES: Dict[str, List[str]] = {
    "Academic & Research": [
        "University/Education",
        "Scientific Journal",
        "Research Institution",
        "Academic Database",
    ],
    "News & Media": [
        "General News",
        "Tech News",
        "Business News",
        "Entertainment News",
        "Local/Regional News",
    ],
    "Reference & Documentation": [
        "Encyclopedia",
        "Technical Documentation",
        "API Documentation",
        "Tutorial/Guide",
        "Dictionary/Glossary",
    ],
    "Social & Community": [
        "Social Network",
        "Forum/Discussion",
        # Question-and-answer focused sites (StackOverflow, Quora)
        "Q&A Platform",
        # Structured publishing platforms (Medium, WordPress.com)
        "Blog Platform",
        # Individual author blogs
        "Personal Blog",
    ],
    "Business & Commerce": [
        "E-commerce",
        "Corporate Website",
        "B2B Platform",
        "Financial Service",
        "Marketing/Advertising",
    ],
    "Technology": [
        "Software Development",
        "Cloud Service",
        "Open Source Project",
        "Tech Company",
        "Developer Tools",
    ],
    "Government & Organization": [
        "Government Agency",
        "Non-profit",
        "International Organization",
        "Think Tank",
        "Industry Association",
    ],
    "Entertainment & Lifestyle": [
        "Streaming Service",
        "Gaming",
        "Sports",
        "Arts & Culture",
        "Travel & Tourism",
    ],
    "Professional & Industry": [
        "Healthcare",
        "Legal",
        "Real Estate",
        "Manufacturing",
        "Energy & Utilities",
    ],
    # Fallback bucket; the prompt tells the LLM to use "Other"/"Unknown"
    # when it is uncertain.
    "Other": ["Personal Website", "Miscellaneous", "Unknown"],
}

79 

80 

class DomainClassifier:
    """Classify domains using an LLM with predefined categories.

    For each domain the classifier collects a few stored resources whose URL
    contains the domain, asks the LLM to pick a category/subcategory from
    ``DOMAIN_CATEGORIES``, and persists the answer as a
    ``DomainClassification`` row in the user's database.
    """

    def __init__(
        self, username: str, settings_snapshot: Optional[dict] = None
    ):
        """Initialize the domain classifier.

        Args:
            username: Username for database session
            settings_snapshot: Settings snapshot for LLM configuration
        """
        self.username = username
        self.settings_snapshot = settings_snapshot
        # Created lazily by _get_llm() on first use.
        self.llm = None

    def _get_llm(self):
        """Get or initialize LLM instance.

        Returns:
            The cached LLM, created via ``get_llm`` on the first call.
        """
        if self.llm is None:
            self.llm = get_llm(settings_snapshot=self.settings_snapshot)
        return self.llm

    @staticmethod
    def _extract_domain(url) -> Optional[str]:
        """Return the lower-cased network location of ``url`` without ``www.``.

        Args:
            url: URL value to parse; may be empty or ``None``

        Returns:
            The normalized netloc, or None when ``url`` is empty, cannot be
            parsed, or has no network location.
        """
        from urllib.parse import urlparse

        if not url:
            return None
        try:
            domain = urlparse(url).netloc.lower()
            if domain.startswith("www."):
                domain = domain[4:]
        except (ValueError, TypeError, AttributeError):
            # Best effort: malformed or non-text URLs are skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return None
        return domain or None

    @staticmethod
    def _parse_llm_json(response_text) -> dict:
        """Parse the JSON object contained in an LLM response.

        The whole text is tried first; if that is not valid JSON, the span
        from the first ``{`` to the last ``}`` is parsed instead, which copes
        with models that wrap the object in prose or code fences.

        Args:
            response_text: Raw text returned by the LLM

        Returns:
            The decoded JSON object

        Raises:
            ValueError: If no JSON can be found or the JSON is not an object
        """
        import re

        try:
            result = json.loads(response_text)
        except json.JSONDecodeError:
            json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
            if not json_match:
                raise ValueError("Could not parse JSON from LLM response")
            result = json.loads(json_match.group())

        if not isinstance(result, dict):
            raise ValueError("LLM response JSON is not an object")
        return result

    @staticmethod
    def _coerce_confidence(value, default: float = 0.5) -> float:
        """Convert an LLM-provided confidence into a float within [0, 1].

        Args:
            value: Raw confidence value from the parsed response
            default: Value used when ``value`` is missing or not numeric

        Returns:
            ``value`` as a float clamped to [0, 1], or ``default`` when it
            cannot be interpreted as a number.
        """
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        if score != score:  # NaN is the only value not equal to itself
            return default
        return min(1.0, max(0.0, score))

    def _get_domain_samples(
        self, domain: str, session: Session, limit: int = 5
    ) -> List[Dict]:
        """Get sample resources from a domain.

        Args:
            domain: Domain to get samples for
            session: Database session
            limit: Maximum number of samples

        Returns:
            List of dicts with ``title``, ``url`` and ``preview`` keys;
            ``preview`` is cut to 200 characters, or None when the resource
            has no content preview.
        """
        resources = (
            session.query(ResearchResource)
            # autoescape keeps "%" and "_" in the domain from acting as
            # LIKE wildcards.
            .filter(ResearchResource.url.contains(domain, autoescape=True))
            .limit(limit)
            .all()
        )

        return [
            {
                "title": resource.title or "Untitled",
                "url": resource.url,
                "preview": resource.content_preview[:200]
                if resource.content_preview
                else None,
            }
            for resource in resources
        ]

    def _build_classification_prompt(
        self, domain: str, samples: List[Dict]
    ) -> str:
        """Build prompt for LLM classification.

        This method uses actual content samples (titles, previews) from the
        domain rather than relying solely on domain name patterns, providing
        more accurate classification based on actual site content.

        Args:
            domain: Domain to classify
            samples: Sample resources from the domain (titles, URLs,
                content previews); at most the first five are used

        Returns:
            Formatted prompt string
        """
        category_lines = [
            f"{main_cat}: {', '.join(subcats)}"
            for main_cat, subcats in DOMAIN_CATEGORIES.items()
        ]

        sample_lines = []
        for i, sample in enumerate(samples[:5], 1):
            sample_lines.append(f"{i}. Title: {sample['title']}")
            if sample.get("preview"):
                sample_lines.append(f" Preview: {sample['preview'][:100]}...")

        samples_block = (
            "\n".join(sample_lines) if sample_lines else "No samples available"
        )
        categories_block = "\n".join(category_lines)

        return (
            "Classify the following domain into one of the predefined categories.\n"
            "\n"
            f"Domain: {domain}\n"
            "\n"
            "Sample content from this domain:\n"
            f"{samples_block}\n"
            "\n"
            "Available Categories:\n"
            f"{categories_block}\n"
            "\n"
            "Respond with a JSON object containing:\n"
            '- "category": The main category (e.g., "News & Media")\n'
            '- "subcategory": The specific subcategory (e.g., "Tech News")\n'
            '- "confidence": A confidence score between 0 and 1\n'
            '- "reasoning": A brief explanation (max 100 words) of why this classification was chosen\n'
            "\n"
            'Focus on accuracy. If uncertain, use "Other" category with "Unknown" subcategory.\n'
            "\n"
            "JSON Response:"
        )

    def classify_domain(
        self, domain: str, force_update: bool = False
    ) -> Optional[DomainClassification]:
        """Classify a single domain using LLM.

        Args:
            domain: Domain to classify
            force_update: If True, reclassify even if already exists

        Returns:
            DomainClassification object or None if failed. Any exception
            raised while classifying is logged and reported as None.
        """
        try:
            with get_user_db_session(self.username) as session:
                # Check if already classified
                existing = (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )

                if existing and not force_update:
                    logger.info(
                        f"Domain {domain} already classified as {existing.category}"
                    )
                    return existing

                # Gather evidence, then ask the LLM.
                samples = self._get_domain_samples(domain, session)
                prompt = self._build_classification_prompt(domain, samples)
                response = self._get_llm().invoke(prompt)

                # Use the response's ``content`` attribute when present,
                # otherwise its string form.
                response_text = (
                    response.content
                    if hasattr(response, "content")
                    else str(response)
                )
                result = self._parse_llm_json(response_text)

                # Same field values whether we update or insert.
                fields = {
                    "category": result.get("category", "Other"),
                    "subcategory": result.get("subcategory", "Unknown"),
                    "confidence": self._coerce_confidence(
                        result.get("confidence", 0.5)
                    ),
                    "reasoning": result.get("reasoning", ""),
                    "sample_titles": json.dumps(
                        [s["title"] for s in samples]
                    ),
                    "sample_count": len(samples),
                }

                if existing:
                    for attr, value in fields.items():
                        setattr(existing, attr, value)
                    classification = existing
                else:
                    classification = DomainClassification(
                        domain=domain, **fields
                    )
                    session.add(classification)

                session.commit()
                logger.info(
                    f"Classified {domain} as {classification.category}/{classification.subcategory} with confidence {classification.confidence}"
                )
                return classification

        except Exception:
            logger.exception(f"Error classifying domain {domain}")
            return None

    def classify_all_domains(
        self, force_update: bool = False, progress_callback=None
    ) -> Dict:
        """Classify all unique domains in the database.

        Args:
            force_update: If True, reclassify all domains
            progress_callback: Optional callback function for progress
                updates; called once per domain with a dict holding
                ``current``, ``total``, ``domain`` and ``percentage``

        Returns:
            Dictionary with ``total``, ``classified``, ``failed`` and
            ``skipped`` counts plus a per-domain ``domains`` list. An
            ``error`` key is added when the run itself fails.
        """
        results = {
            "total": 0,
            "classified": 0,
            "failed": 0,
            "skipped": 0,
            "domains": [],
        }

        try:
            with get_user_db_session(self.username) as session:
                # Get all unique domains
                url_rows = session.query(ResearchResource.url).distinct().all()
                domains = set()
                for (url,) in url_rows:
                    domain = self._extract_domain(url)
                    if domain:
                        domains.add(domain)

                results["total"] = len(domains)
                logger.info(
                    f"Found {results['total']} unique domains to process"
                )

                # Classify each domain ONE BY ONE
                for i, domain in enumerate(sorted(domains), 1):
                    logger.info(
                        f"Processing domain {i}/{results['total']}: {domain}"
                    )

                    if progress_callback:
                        progress_callback(
                            {
                                "current": i,
                                "total": results["total"],
                                "domain": domain,
                                "percentage": (i / results["total"]) * 100,
                            }
                        )

                    try:
                        # Check if already classified
                        if not force_update:
                            existing = (
                                session.query(DomainClassification)
                                .filter_by(domain=domain)
                                .first()
                            )
                            if existing:
                                results["skipped"] += 1
                                results["domains"].append(
                                    {
                                        "domain": domain,
                                        "status": "skipped",
                                        "category": existing.category,
                                        "subcategory": existing.subcategory,
                                    }
                                )
                                logger.info(
                                    f"Domain {domain} already classified, skipping"
                                )
                                continue

                        # Classify this single domain
                        classification = self.classify_domain(
                            domain, force_update
                        )

                        if classification:
                            results["classified"] += 1
                            results["domains"].append(
                                {
                                    "domain": domain,
                                    "status": "classified",
                                    "category": classification.category,
                                    "subcategory": classification.subcategory,
                                    "confidence": classification.confidence,
                                }
                            )
                            logger.info(
                                f"Successfully classified {domain} as {classification.category}"
                            )
                        else:
                            results["failed"] += 1
                            results["domains"].append(
                                {"domain": domain, "status": "failed"}
                            )
                            logger.warning(
                                f"Failed to classify domain {domain}"
                            )

                    except Exception:
                        # One bad domain must not abort the whole run.
                        logger.exception(f"Error classifying domain {domain}")
                        results["failed"] += 1
                        results["domains"].append(
                            {
                                "domain": domain,
                                "status": "failed",
                                "error": "Classification failed",
                            }
                        )

                logger.info(
                    f"Classification complete: {results['classified']} classified, {results['skipped']} skipped, {results['failed']} failed"
                )
                return results

        except Exception:
            logger.exception("Error in classify_all_domains")
            results["error"] = "Classification failed"
            return results

    def get_classification(self, domain: str) -> Optional[DomainClassification]:
        """Get existing classification for a domain.

        Args:
            domain: Domain to look up

        Returns:
            DomainClassification object or None if not found. Lookup errors
            are logged and also reported as None.
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .filter_by(domain=domain)
                    .first()
                )
        except Exception:
            logger.exception(f"Error getting classification for {domain}")
            return None

    def get_all_classifications(self) -> List[DomainClassification]:
        """Get all domain classifications.

        Returns:
            List of all DomainClassification objects ordered by category and
            then domain; an empty list if the query fails.
        """
        try:
            with get_user_db_session(self.username) as session:
                return (
                    session.query(DomainClassification)
                    .order_by(
                        DomainClassification.category,
                        DomainClassification.domain,
                    )
                    .all()
                )
        except Exception:
            logger.exception("Error getting all classifications")
            return []

    def get_categories_summary(self) -> Dict:
        """Get summary of domain classifications by category.

        Returns:
            Dictionary keyed by category. Each value holds ``count``, a
            ``domains`` list (domain, subcategory, confidence) and a
            ``subcategories`` mapping of subcategory name to count. An empty
            dict is returned if the query fails.
        """
        try:
            with get_user_db_session(self.username) as session:
                classifications = session.query(DomainClassification).all()

                summary = {}
                for classification in classifications:
                    bucket = summary.setdefault(
                        classification.category,
                        {"count": 0, "domains": [], "subcategories": {}},
                    )

                    bucket["count"] += 1
                    bucket["domains"].append(
                        {
                            "domain": classification.domain,
                            "subcategory": classification.subcategory,
                            "confidence": classification.confidence,
                        }
                    )

                    subcat = classification.subcategory
                    if subcat:
                        subcat_counts = bucket["subcategories"]
                        subcat_counts[subcat] = subcat_counts.get(subcat, 0) + 1

                return summary

        except Exception:
            logger.exception("Error getting categories summary")
            return {}