rahuln2002 commited on
Commit
14f8a5c
·
verified ·
1 Parent(s): b47d542

Update knowledgeassistant/components/summarization.py

Browse files
knowledgeassistant/components/summarization.py CHANGED
@@ -17,15 +17,42 @@ class DataSummarization:
17
 
18
  def summarize(self, input_text_path: str, min_length: int):
19
  try:
20
- pipe = pipeline("summarization", model="/app/models/bart-large-cnn")
 
 
 
 
21
  logging.info("Summarization Pipeline Successfully Setup")
 
22
  text = read_txt_file(input_text_path)
23
- summary = pipe(text, min_length = min_length, do_sample = False)
 
 
 
 
 
 
 
 
 
 
 
 
24
  logging.info("Text successfully summarized")
 
 
25
  write_txt_file(self.data_summarization_config.summarized_text_file_path, summary[0].get("summary_text"))
26
  logging.info("Successfully wrote summarized text")
 
 
 
 
 
 
 
27
  except Exception as e:
28
  raise KnowledgeAssistantException(e, sys)
 
29
 
30
  def initiate_data_summarization(self, input_text_path: str, min_length: int):
31
  try:
 
17
 
18
  def summarize(self, input_text_path: str, min_length: int):
19
  try:
20
+ model_path = "/app/models/bart-large-cnn"
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
23
+
24
+ pipe = pipeline("summarization", model=model_path, tokenizer=model_path)
25
  logging.info("Summarization Pipeline Successfully Setup")
26
+
27
  text = read_txt_file(input_text_path)
28
+
29
+ tokens = tokenizer.encode(text, truncation=True, max_length=1024, return_tensors="pt")
30
+
31
+ if len(tokens[0]) >= 1024:
32
+ logging.warning("Input text exceeded 1024 tokens. It has been truncated.")
33
+ truncated_text = tokenizer.decode(tokens[0], skip_special_tokens=True)
34
+ frontend_message = "Your input text exceeded the limit of 1024 tokens and has been truncated."
35
+ else:
36
+ truncated_text = text
37
+ frontend_message = ""
38
+
39
+ # Generate summary
40
+ summary = pipe(truncated_text, min_length=min_length, max_length=142, do_sample=False)
41
  logging.info("Text successfully summarized")
42
+
43
+ # Save summary
44
  write_txt_file(self.data_summarization_config.summarized_text_file_path, summary[0].get("summary_text"))
45
  logging.info("Successfully wrote summarized text")
46
+
47
+ # Return summary along with frontend message
48
+ return {
49
+ "summary": summary[0].get("summary_text"),
50
+ "warning": frontend_message
51
+ }
52
+
53
  except Exception as e:
54
  raise KnowledgeAssistantException(e, sys)
55
+
56
 
57
  def initiate_data_summarization(self, input_text_path: str, min_length: int):
58
  try: