Toumaima committed
Commit 6061cbb · verified · parent: 73f1372

Update app.py

Files changed (1)
  app.py +32 -2
app.py CHANGED
@@ -14,6 +14,7 @@ import gradio as gr
 import pandas as pd
 from spacy.cli import download
 
+
 class BasicAgent:
     def __init__(self):
         print("BasicAgent initialized.")
@@ -23,10 +24,39 @@ class BasicAgent:
         download("en_core_web_sm")
         self.spacy = spacy.load("en_core_web_sm")
         self.whisper_model = whisper.load_model("base")
-        self.qa_pipeline = pipeline("question-answering")
+        self.qa_pipeline = pipeline("question-answering", truncation=True, padding=True)
         self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
-        self.embedding_model = pipeline("feature-extraction")
+        self.embedding_model = pipeline("feature-extraction", truncation=True)
+
+    def split_text_into_chunks(self, text, max_length=512):
+        """Split text into whitespace-delimited chunks of at most `max_length` characters."""
+        words = text.split()
+        chunks = []
+        chunk = []
+
+        for word in words:
+            chunk.append(word)
+            if len(' '.join(chunk)) > max_length:
+                chunks.append(' '.join(chunk[:-1]))  # Emit the chunk without the overflowing word
+                chunk = [word]  # Start the next chunk with it
+
+        if chunk:
+            chunks.append(' '.join(chunk))  # Add the final chunk
+
+        return chunks
 
+    def answer_question(self, question: str, context: str) -> str:
+        try:
+            context_chunks = self.split_text_into_chunks(context, max_length=512)
+            answers = []
+            for chunk in context_chunks:
+                answer = self.qa_pipeline(question=question, context=chunk)["answer"]
+                answers.append(answer)
+
+            return " ".join(answers)  # Combine the per-chunk answers
+        except Exception as e:
+            return f"Error answering question: {e}"
+
     def extract_named_entities(self, text):
         entities = self.ner_pipeline(text)
         return [e["word"] for e in entities if e["entity_group"] == "PER"]
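
For reference, a minimal usage sketch of the chunked question answering added in this commit. It is written against the BasicAgent class above; the context and question strings are invented for illustration, and it assumes the dependencies app.py imports (transformers, openai-whisper, spacy, gradio, pandas) are installed so that BasicAgent() can build its pipelines:

agent = BasicAgent()

# A context long enough to require splitting (repeated filler text).
context = "Ada Lovelace worked with Charles Babbage on the Analytical Engine. " * 40

# split_text_into_chunks caps each chunk at roughly 512 characters,
# a conservative proxy for the QA model's input limit.
chunks = agent.split_text_into_chunks(context, max_length=512)
print(len(chunks), max(len(c) for c in chunks))

# answer_question runs the QA pipeline on each chunk and joins the answers.
print(agent.answer_question("Who worked with Charles Babbage?", context))

Joining the per-chunk answers is the simplest aggregation; keeping only the answer with the highest pipeline score would be a natural refinement.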