christopher committed
Commit c8d57fb · 1 Parent(s): 1db196f

changed nlp and query processors to fix issues with lists

Files changed (2):
  1. database/query_processor.py +59 -44
  2. models/nlp.py +5 -9
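The commit message is terse, so it is worth spelling out what "issues with lists" appears to mean. Reading the two diffs below together: `NLPModel.__call__` delegates to `extract_entities()`, which returns a plain Python list of `(text, label)` tuples rather than a spaCy `Doc`. The old `QueryProcessor` code called `doc = self.nlp_model(query)` and then read `doc.ents`, which fails on a list. A minimal sketch of that failure mode, with `NLPModel` stubbed out (the stub's entity values are invented for illustration):

# Hypothetical stub mirroring the wrapper in models/nlp.py: __call__ returns
# a list of (text, label) tuples, not a spaCy Doc.
class NLPModel:
    def __call__(self, text):
        return self.extract_entities(text)

    def extract_entities(self, text):
        return [("paris", "GPE")]  # stand-in for real spaCy output

nlp_model = NLPModel()
doc = nlp_model("weather in Paris")  # actually a list, despite the name
try:
    entities = [ent.text.lower() for ent in doc.ents]  # old QueryProcessor line
except AttributeError as e:
    print(e)  # 'list' object has no attribute 'ents'

The patched `QueryProcessor` sidesteps this by calling `extract_entities()` explicitly and indexing into the tuples, as shown in the diff below.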
database/query_processor.py CHANGED
@@ -2,6 +2,9 @@ import datetime
 from typing import List, Dict, Any, Optional
 import numpy as np
 from models.LexRank import degree_centrality_scores
+import logging
+
+logger = logging.getLogger(__name__)
 
 class QueryProcessor:
     def __init__(self, embedding_model, summarization_model, nlp_model, db_service):
@@ -17,51 +20,63 @@ class QueryProcessor:
         start_date: Optional[str] = None,
         end_date: Optional[str] = None
     ) -> Dict[str, Any]:
-        # Convert string dates to datetime objects
-        start_dt = datetime.strptime(start_date, "%Y-%m-%d") if start_date else None
-        end_dt = datetime.strptime(end_date, "%Y-%m-%d") if end_date else None
-
-        # Get query embedding
-        query_embedding = self.embedding_model.encode(query).tolist()
-
-        # Get entities from the query
-        doc = self.nlp_model(query)
-        entities = [ent.text.lower() for ent in doc.ents]  # Extract entity texts
-
-        # Semantic search with entities
-        articles = await self.db_service.semantic_search(
-            query_embedding=query_embedding,
-            start_date=start_dt,
-            end_date=end_dt,
-            topic=topic,
-            entities=entities  # Pass entities to the search
-        )
-
-        if not articles:
-            return {"error": "No articles found matching the criteria"}
-
-        # Step 3: Process results
-        contents = [article["content"] for article in articles]
-        sentences = []
-        for content in contents:
-            sentences.extend(self.nlp_model.tokenize_sentences(content))
-
-        # Step 4: Generate summary
-        if sentences:
-            embeddings = self.embedding_model.encode(sentences)
-            similarity_matrix = np.inner(embeddings, embeddings)
-            centrality_scores = degree_centrality_scores(similarity_matrix, threshold=None)
-
-            top_indices = np.argsort(-centrality_scores)[0:10]
-            key_sentences = [sentences[idx].strip() for idx in top_indices]
-            combined_text = ' '.join(key_sentences)
-
-            summary = self.summarization_model.summarize(combined_text)
-        else:
-            key_sentences = []
-            summary = "No content available for summarization"
-
-        return {
-            "summary": summary,
-            "articles": articles
-        }
+        try:
+            # Convert string dates to datetime objects
+            start_dt = datetime.strptime(start_date, "%Y-%m-%d") if start_date else None
+            end_dt = datetime.strptime(end_date, "%Y-%m-%d") if end_date else None
+
+            # Get query embedding
+            query_embedding = self.embedding_model.encode(query).tolist()
+            logger.debug(f"Generated query embedding for: {query[:50]}...")
+
+            # Extract entities using the NLP model
+            entities = self.nlp_model.extract_entities(query)  # Changed from direct call to using method
+            logger.debug(f"Extracted entities: {entities}")
+
+            # Semantic search with entities
+            articles = await self.db_service.semantic_search(
+                query_embedding=query_embedding,
+                start_date=start_dt,
+                end_date=end_dt,
+                topic=topic,
+                entities=[ent[0] for ent in entities]  # Using just the entity texts
+            )
+
+            if not articles:
+                logger.info("No articles found matching search criteria")
+                return {"error": "No articles found matching the criteria"}
+
+            # Process results
+            contents = [article["content"] for article in articles]
+            sentences = []
+            for content in contents:
+                sentences.extend(self.nlp_model.tokenize_sentences(content))
+
+            logger.debug(f"Processing {len(sentences)} sentences for summarization")
+
+            # Generate summary
+            if sentences:
+                embeddings = self.embedding_model.encode(sentences)
+                similarity_matrix = np.inner(embeddings, embeddings)
+                centrality_scores = degree_centrality_scores(similarity_matrix, threshold=None)
+
+                top_indices = np.argsort(-centrality_scores)[0:10]
+                key_sentences = [sentences[idx].strip() for idx in top_indices]
+                combined_text = ' '.join(key_sentences)
+
+                summary = self.summarization_model.summarize(combined_text)
+                logger.debug(f"Generated summary with {len(key_sentences)} key sentences")
+            else:
+                key_sentences = []
+                summary = "No content available for summarization"
+                logger.warning("No sentences available for summarization")
+
+            return {
+                "summary": summary,
+                "articles": articles,
+                "entities": entities  # Include extracted entities in response
+            }
+
+        except Exception as e:
+            logger.error(f"Error in QueryProcessor: {str(e)}", exc_info=True)
+            return {"error": f"Processing error: {str(e)}"}
models/nlp.py CHANGED
@@ -11,15 +11,11 @@ class NLPModel:
         return self.extract_entities(text)  # or another default method
 
     def extract_entities(self, text: str):
-        if isinstance(text, list):  # If input is a list of sentences
-            entities = []
-            for sentence in text:
-                doc = self.nlp(sentence)
-                entities.extend([(ent.text.lower(), ent.label_) for ent in doc.ents])
-            return entities
-        else:  # If input is a single string
-            doc = self.nlp(text)
-            return [(ent.text.lower(), ent.label_) for ent in doc.ents]
+        """Ensure this always takes a string and returns entities"""
+        if isinstance(text, list):  # If accidentally passed a list
+            text = " ".join(text)  # Combine into single string
+        doc = self.nlp(text)
+        return [(ent.text.lower(), ent.label_) for ent in doc.ents]
 
 
     def tokenize_sentences(self, text: str):
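The net effect in `models/nlp.py`: list input is normalized to a single string before spaCy sees it, so both call styles take the same path and return the same flat list of `(text, label)` tuples. A runnable sketch, assuming the class wraps a spaCy pipeline; the `en_core_web_sm` model name is an assumption, and the printed entities are what that model typically produces:

import spacy

nlp = spacy.load("en_core_web_sm")  # assumed model; any NER-capable pipeline works

def extract_entities(text):
    """Mirrors the patched method: always normalize to a single string."""
    if isinstance(text, list):               # e.g. output of tokenize_sentences()
        text = " ".join(text)                # collapse before spaCy sees it
    doc = nlp(text)                          # spaCy rejects non-string input
    return [(ent.text.lower(), ent.label_) for ent in doc.ents]

print(extract_entities("Apple opened an office in Paris."))
# typically: [('apple', 'ORG'), ('paris', 'GPE')]
print(extract_entities(["Apple opened an office.", "It is in Paris."]))
# same flat (text, label) shape the removed list branch produced, via one spaCy call

One nuance of the join: spaCy now runs NER over the concatenated text rather than sentence by sentence, which can occasionally merge or split entity spans differently at sentence boundaries; for this commit's purpose of feeding entity texts into search, that trade-off looks acceptable.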