# test/database/query_processor.py
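"""Query processing pipeline: embeds a search query, extracts named
entities, runs a semantic search against the article store, and builds a
summary by ranking sentences with degree-centrality (LexRank) scores."""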
import logging
from datetime import datetime as dt
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from models.LexRank import degree_centrality_scores

logger = logging.getLogger(__name__)
class QueryProcessor:
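    """Coordinates the embedding, NLP, summarization, and database services
    to answer a query with matching articles and a generated summary."""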

    def __init__(self, embedding_model, summarization_model, nlp_model, db_service):
        self.embedding_model = embedding_model
        self.summarization_model = summarization_model
        self.nlp_model = nlp_model
        self.db_service = db_service

    async def process(
        self,
        query: str,
        topic: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        try:
            # Validate and parse dates
            start_dt = self._parse_date(start_date) if start_date else None
            end_dt = self._parse_date(end_date) if end_date else None

            # Get query embedding
            query_embedding = self.embedding_model.encode(query).tolist()
            logger.debug(f"Query embedding generated for: {query[:50]}...")

            # Extract entities safely
            entities = self._extract_entities_safely(query)
            logger.debug(f"Extracted entities: {entities}")

            # Semantic search
            articles = await self._execute_semantic_search(
                query_embedding,
                start_dt,
                end_dt,
                topic,
                entities,
            )
            if not articles:
                logger.info("No articles found matching criteria")
                return {"message": "No articles found", "articles": []}

            # Process results and generate summary
            summary_result = self._generate_summary(articles)
            return {
                "summary": summary_result["summary"],
                "key_sentences": summary_result["key_sentences"],
                "articles": articles,
                "entities": entities,
            }
        except Exception as e:
            logger.error(f"Processing failed: {str(e)}", exc_info=True)
            return {"error": str(e)}

    def _parse_date(self, date_str: str) -> dt:
        """Safe date parsing with validation."""
        try:
            return dt.strptime(date_str, "%Y-%m-%d")
        except ValueError as e:
            logger.error(f"Invalid date format: {date_str}")
            raise ValueError(f"Invalid date format. Expected YYYY-MM-DD, got {date_str}") from e

    def _extract_entities_safely(self, text: Union[str, List[str]]) -> List[Tuple[str, str]]:
        """Robust entity extraction handling both strings and lists."""
        try:
            if isinstance(text, list):
                logger.warning("Received list input for entity extraction, joining to string")
                text = " ".join(text)
            return self.nlp_model.extract_entities(text)
        except Exception as e:
            logger.error(f"Entity extraction failed: {str(e)}")
            return []

    async def _execute_semantic_search(
        self,
        query_embedding: List[float],
        start_date: Optional[dt],
        end_date: Optional[dt],
        topic: Optional[str],
        entities: List[Tuple[str, str]],
    ) -> List[Dict[str, Any]]:
        """Execute search with proper error handling."""
        try:
            # Pass only the entity surface forms; labels are not used in search.
            entity_texts = [ent[0] for ent in entities]
            return await self.db_service.semantic_search(
                query_embedding=query_embedding,
                start_date=start_date,
                end_date=end_date,
                topic=topic,
                entities=entity_texts,
            )
        except Exception as e:
            logger.error(f"Semantic search failed: {str(e)}")
            raise

    def _generate_summary(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate summary from articles with fallback handling."""
        try:
            contents = [article.get("content") for article in articles]
            sentences = []
            for content in contents:
                if content:
                    sentences.extend(self.nlp_model.tokenize_sentences(content))
            if not sentences:
                logger.warning("No sentences available for summarization")
                return {
                    "summary": "No content available for summarization",
                    "key_sentences": [],
                }

            # Rank sentences by degree centrality over the pairwise
            # similarity matrix (LexRank) and keep the top 10.
            embeddings = self.embedding_model.encode(sentences)
            similarity_matrix = np.inner(embeddings, embeddings)
            centrality_scores = degree_centrality_scores(similarity_matrix, threshold=None)
            top_indices = np.argsort(-centrality_scores)[:10]
            key_sentences = [sentences[idx].strip() for idx in top_indices]

            combined_text = " ".join(key_sentences)
            return {
                "summary": self.summarization_model.summarize(combined_text),
                "key_sentences": key_sentences,
            }
        except Exception as e:
            logger.error(f"Summary generation failed: {str(e)}")
            return {
                "summary": "Summary generation failed",
                "key_sentences": [],
            }