maahi2412 committed on
Commit a8d28cf · verified · 1 Parent(s): b0bb1a1

Update app.py

Files changed (1)
  1. app.py +430 -106

app.py CHANGED
@@ -1,134 +1,458 @@
- from flask import Flask, request, jsonify
  import os
  import pdfplumber
- import pytesseract
  from PIL import Image
- from transformers import PegasusForConditionalGeneration, PegasusTokenizer
  import torch
- import logging

  app = Flask(__name__)

- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)

- # Load Pegasus Model (load once globally)
- logger.info("Loading Pegasus model and tokenizer...")
- tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
- model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum").to("cpu")  # Force CPU to manage memory
- logger.info("Model loaded successfully.")

- # Extract text from PDF with page limit
- def extract_text_from_pdf(file_path, max_pages=5):
      text = ""
-     try:
-         with pdfplumber.open(file_path) as pdf:
-             total_pages = len(pdf.pages)
-             pages_to_process = min(total_pages, max_pages)
-             logger.info(f"Extracting text from {pages_to_process} of {total_pages} pages in {file_path}")
-             for i, page in enumerate(pdf.pages[:pages_to_process]):
-                 try:
-                     extracted = page.extract_text()
-                     if extracted:
-                         text += extracted + "\n"
-                     else:
-                         logger.info(f"No text on page {i+1}, attempting OCR...")
-                         image = page.to_image().original
-                         text += pytesseract.image_to_string(image) + "\n"
-                 except Exception as e:
-                     logger.warning(f"Error processing page {i+1}: {e}")
-                     continue
-     except Exception as e:
-         logger.error(f"Failed to process PDF {file_path}: {e}")
-         return ""
-     return text.strip()

- # Extract text from image (OCR)
  def extract_text_from_image(file_path):
-     try:
-         logger.info(f"Extracting text from image {file_path} using OCR...")
-         image = Image.open(file_path)
-         text = pytesseract.image_to_string(image)
-         return text.strip()
-     except Exception as e:
-         logger.error(f"Failed to process image {file_path}: {e}")
-         return ""

- # Summarize text with chunking for large inputs
- def summarize_text(text, max_input_length=512, max_output_length=150):
-     try:
-         logger.info("Summarizing text...")
-         # Tokenize and truncate to max_input_length
-         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length, padding=True)
-         input_length = inputs["input_ids"].shape[1]
-         logger.info(f"Input length: {input_length} tokens")
-
-         # Adjust generation params for efficiency
-         summary_ids = model.generate(
-             inputs["input_ids"],
-             max_length=max_output_length,
-             min_length=30,
-             num_beams=2,  # Reduce beams for speedup
-             early_stopping=True,
-             length_penalty=1.0,  # Encourage shorter outputs
-         )
-         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-         logger.info("Summarization completed.")
-         return summary
-     except Exception as e:
-         logger.error(f"Error during summarization: {e}")
-         return ""

  @app.route('/summarize', methods=['POST'])
  def summarize_document():
      if 'file' not in request.files:
-         logger.error("No file uploaded in request.")
          return jsonify({"error": "No file uploaded"}), 400

      file = request.files['file']
      filename = file.filename
-     if not filename:
-         logger.error("Empty filename in request.")
-         return jsonify({"error": "No file uploaded"}), 400
-
-     file_path = os.path.join("/tmp", filename)
      try:
          file.save(file_path)
-         logger.info(f"File saved to {file_path}")
-
-         if filename.lower().endswith('.pdf'):
-             text = extract_text_from_pdf(file_path, max_pages=2)  # Reduce to 2 pages
-         elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
              text = extract_text_from_image(file_path)
          else:
-             logger.error(f"Unsupported file format: {filename}")
-             return jsonify({"error": "Unsupported file format. Use PDF, PNG, JPEG, or JPG"}), 400
-
-         if not text:
-             logger.warning(f"No text extracted from {filename}")
-             return jsonify({"error": "No text extracted from the file"}), 400
-
-         summary = summarize_text(text)
-         if not summary:
-             logger.warning("Summarization failed to produce output.")
-             return jsonify({"error": "Failed to generate summary"}), 500
-
-         logger.info(f"Summary generated for {filename}")
-         return jsonify({"summary": summary})
-
      except Exception as e:
-         logger.error(f"Unexpected error processing {filename}: {e}")
-         return jsonify({"error": str(e)}), 500
-
-     finally:
-         if os.path.exists(file_path):
-             try:
-                 os.remove(file_path)
-                 logger.info(f"Cleaned up file: {file_path}")
-             except Exception as e:
-                 logger.warning(f"Failed to delete {file_path}: {e}")

  if __name__ == '__main__':
-     logger.info("Starting Flask app...")
-     app.run(host='0.0.0.0', port=7860)
+ # from flask import Flask, request, jsonify
+ # import os
+ # import pdfplumber
+ # import pytesseract
+ # from PIL import Image
+ # from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+ # import torch
+ # import logging
+
+ # app = Flask(__name__)
+
+ # # Set up logging
+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)
+
+ # # Load Pegasus Model (load once globally)
+ # logger.info("Loading Pegasus model and tokenizer...")
+ # tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+ # model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum").to("cpu")  # Force CPU to manage memory
+ # logger.info("Model loaded successfully.")
+
+ # # Extract text from PDF with page limit
+ # def extract_text_from_pdf(file_path, max_pages=5):
+ #     text = ""
+ #     try:
+ #         with pdfplumber.open(file_path) as pdf:
+ #             total_pages = len(pdf.pages)
+ #             pages_to_process = min(total_pages, max_pages)
+ #             logger.info(f"Extracting text from {pages_to_process} of {total_pages} pages in {file_path}")
+ #             for i, page in enumerate(pdf.pages[:pages_to_process]):
+ #                 try:
+ #                     extracted = page.extract_text()
+ #                     if extracted:
+ #                         text += extracted + "\n"
+ #                     else:
+ #                         logger.info(f"No text on page {i+1}, attempting OCR...")
+ #                         image = page.to_image().original
+ #                         text += pytesseract.image_to_string(image) + "\n"
+ #                 except Exception as e:
+ #                     logger.warning(f"Error processing page {i+1}: {e}")
+ #                     continue
+ #     except Exception as e:
+ #         logger.error(f"Failed to process PDF {file_path}: {e}")
+ #         return ""
+ #     return text.strip()
+
+ # # Extract text from image (OCR)
+ # def extract_text_from_image(file_path):
+ #     try:
+ #         logger.info(f"Extracting text from image {file_path} using OCR...")
+ #         image = Image.open(file_path)
+ #         text = pytesseract.image_to_string(image)
+ #         return text.strip()
+ #     except Exception as e:
+ #         logger.error(f"Failed to process image {file_path}: {e}")
+ #         return ""
+
+ # # Summarize text with chunking for large inputs
+ # def summarize_text(text, max_input_length=512, max_output_length=150):
+ #     try:
+ #         logger.info("Summarizing text...")
+ #         # Tokenize and truncate to max_input_length
+ #         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length, padding=True)
+ #         input_length = inputs["input_ids"].shape[1]
+ #         logger.info(f"Input length: {input_length} tokens")
+
+ #         # Adjust generation params for efficiency
+ #         summary_ids = model.generate(
+ #             inputs["input_ids"],
+ #             max_length=max_output_length,
+ #             min_length=30,
+ #             num_beams=2,  # Reduce beams for speedup
+ #             early_stopping=True,
+ #             length_penalty=1.0,  # Encourage shorter outputs
+ #         )
+ #         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+ #         logger.info("Summarization completed.")
+ #         return summary
+ #     except Exception as e:
+ #         logger.error(f"Error during summarization: {e}")
+ #         return ""
+
+ # @app.route('/summarize', methods=['POST'])
+ # def summarize_document():
+ #     if 'file' not in request.files:
+ #         logger.error("No file uploaded in request.")
+ #         return jsonify({"error": "No file uploaded"}), 400
+
+ #     file = request.files['file']
+ #     filename = file.filename
+ #     if not filename:
+ #         logger.error("Empty filename in request.")
+ #         return jsonify({"error": "No file uploaded"}), 400
+
+ #     file_path = os.path.join("/tmp", filename)
+ #     try:
+ #         file.save(file_path)
+ #         logger.info(f"File saved to {file_path}")
+
+ #         if filename.lower().endswith('.pdf'):
+ #             text = extract_text_from_pdf(file_path, max_pages=2)  # Reduce to 2 pages
+ #         elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
+ #             text = extract_text_from_image(file_path)
+ #         else:
+ #             logger.error(f"Unsupported file format: {filename}")
+ #             return jsonify({"error": "Unsupported file format. Use PDF, PNG, JPEG, or JPG"}), 400
+
+ #         if not text:
+ #             logger.warning(f"No text extracted from {filename}")
+ #             return jsonify({"error": "No text extracted from the file"}), 400
+
+ #         summary = summarize_text(text)
+ #         if not summary:
+ #             logger.warning("Summarization failed to produce output.")
+ #             return jsonify({"error": "Failed to generate summary"}), 500
+
+ #         logger.info(f"Summary generated for {filename}")
+ #         return jsonify({"summary": summary})
+
+ #     except Exception as e:
+ #         logger.error(f"Unexpected error processing {filename}: {e}")
+ #         return jsonify({"error": str(e)}), 500
+
+ #     finally:
+ #         if os.path.exists(file_path):
+ #             try:
+ #                 os.remove(file_path)
+ #                 logger.info(f"Cleaned up file: {file_path}")
+ #             except Exception as e:
+ #                 logger.warning(f"Failed to delete {file_path}: {e}")
+
+ # if __name__ == '__main__':
+ #     logger.info("Starting Flask app...")
+ #     app.run(host='0.0.0.0', port=7860)
+
+
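+ # The rewritten app below adds CORS support, an upload size limit, and three
+ # summarization models (Pegasus, BERT, LegalBERT) chosen per document by the
+ # keyword heuristic in choose_model().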
  import os
  import pdfplumber
  from PIL import Image
+ import pytesseract
+ import numpy as np
+ import transformers  # needed for transformers.logging.set_verbosity_error() below
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
+ from datasets import load_dataset, concatenate_datasets
  import torch
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity

  app = Flask(__name__)
+ CORS(app)
+ UPLOAD_FOLDER = 'uploads'
+ PEGASUS_MODEL_DIR = 'fine_tuned_pegasus'
+ BERT_MODEL_DIR = 'fine_tuned_bert'
+ LEGALBERT_MODEL_DIR = 'fine_tuned_legalbert'
+ MAX_FILE_SIZE = 100 * 1024 * 1024
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)

+ transformers.logging.set_verbosity_error()
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

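+ # Each load_or_finetune_* helper below reuses a fine-tuned checkpoint from disk
+ # when one exists; otherwise it fine-tunes the base model once and saves it for
+ # later runs.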
+ # Pegasus Fine-Tuning
+ def load_or_finetune_pegasus():
+     if os.path.exists(PEGASUS_MODEL_DIR):
+         print("Loading fine-tuned Pegasus model...")
+         tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_DIR)
+         model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_DIR)
+     else:
+         print("Fine-tuning Pegasus on CNN/Daily Mail and XSUM...")
+         tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+         model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+
+         # Load and combine datasets
+         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")  # 5K samples
+         xsum = load_dataset("xsum", split="train[:5000]")  # 5K samples
+         # Align column names so the two datasets can be concatenated
+         xsum = xsum.rename_columns({"document": "article", "summary": "highlights"})
+         combined_dataset = concatenate_datasets([cnn_dm, xsum])
+
+         def preprocess_function(examples):
+             inputs = tokenizer(examples["article"] if "article" in examples else examples["document"],
+                                max_length=512, truncation=True, padding="max_length")
+             targets = tokenizer(examples["highlights"] if "highlights" in examples else examples["summary"],
+                                 max_length=400, truncation=True, padding="max_length")
+             inputs["labels"] = targets["input_ids"]
+             return inputs
+
+         tokenized_dataset = combined_dataset.map(preprocess_function, batched=True)
+         train_dataset = tokenized_dataset.select(range(8000))  # 80%
+         eval_dataset = tokenized_dataset.select(range(8000, 10000))  # 20%
+
+         training_args = TrainingArguments(
+             output_dir="./pegasus_finetune",
+             num_train_epochs=3,  # Increased for better fine-tuning
+             per_device_train_batch_size=1,
+             per_device_eval_batch_size=1,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(PEGASUS_MODEL_DIR)
+         tokenizer.save_pretrained(PEGASUS_MODEL_DIR)
+         print(f"Fine-tuned Pegasus saved to {PEGASUS_MODEL_DIR}")
+
+     return tokenizer, model

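+ # BERT (and LegalBERT below) are fine-tuned as binary sentence classifiers:
+ # a sentence is labeled 1 when it also appears in the reference summary, and
+ # the predicted probabilities are later used to rank sentences for extraction.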
+ # BERT Fine-Tuning
+ def load_or_finetune_bert():
+     if os.path.exists(BERT_MODEL_DIR):
+         print("Loading fine-tuned BERT model...")
+         tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_DIR)
+         model = BertForSequenceClassification.from_pretrained(BERT_MODEL_DIR, num_labels=2)
+     else:
+         print("Fine-tuning BERT on CNN/Daily Mail for extractive summarization...")
+         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+
+         # Load dataset and preprocess for sentence classification
+         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")
+
+         def preprocess_for_extractive(examples):
+             sentences = []
+             labels = []
+             for article, highlights in zip(examples["article"], examples["highlights"]):
+                 article_sents = article.split(". ")
+                 highlight_sents = highlights.split(". ")
+                 for sent in article_sents:
+                     if sent.strip():
+                         # Label as 1 if sentence is similar to any highlight, else 0
+                         is_summary = any(sent.strip() in h for h in highlight_sents)
+                         sentences.append(sent)
+                         labels.append(1 if is_summary else 0)
+             return {"sentence": sentences, "label": labels}
+
+         dataset = cnn_dm.map(preprocess_for_extractive, batched=True, remove_columns=["article", "highlights", "id"])
+         tokenized_dataset = dataset.map(
+             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+             batched=True
+         )
+         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+         training_args = TrainingArguments(
+             output_dir="./bert_finetune",
+             num_train_epochs=3,
+             per_device_train_batch_size=8,
+             per_device_eval_batch_size=8,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(BERT_MODEL_DIR)
+         tokenizer.save_pretrained(BERT_MODEL_DIR)
+         print(f"Fine-tuned BERT saved to {BERT_MODEL_DIR}")
+
+     return tokenizer, model
+
+ # LegalBERT Fine-Tuning
+ def load_or_finetune_legalbert():
+     if os.path.exists(LEGALBERT_MODEL_DIR):
+         print("Loading fine-tuned LegalBERT model...")
+         tokenizer = BertTokenizer.from_pretrained(LEGALBERT_MODEL_DIR)
+         model = BertForSequenceClassification.from_pretrained(LEGALBERT_MODEL_DIR, num_labels=2)
+     else:
+         print("Fine-tuning LegalBERT on Billsum for extractive summarization...")
+         tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
+         model = BertForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=2)
+
+         # Load dataset
+         billsum = load_dataset("billsum", split="train[:5000]")
+
+         def preprocess_for_extractive(examples):
+             sentences = []
+             labels = []
+             for text, summary in zip(examples["text"], examples["summary"]):
+                 text_sents = text.split(". ")
+                 summary_sents = summary.split(". ")
+                 for sent in text_sents:
+                     if sent.strip():
+                         is_summary = any(sent.strip() in s for s in summary_sents)
+                         sentences.append(sent)
+                         labels.append(1 if is_summary else 0)
+             return {"sentence": sentences, "label": labels}
+
+         dataset = billsum.map(preprocess_for_extractive, batched=True, remove_columns=["text", "summary", "title"])
+         tokenized_dataset = dataset.map(
+             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+             batched=True
+         )
+         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+         training_args = TrainingArguments(
+             output_dir="./legalbert_finetune",
+             num_train_epochs=3,
+             per_device_train_batch_size=8,
+             per_device_eval_batch_size=8,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(LEGALBERT_MODEL_DIR)
+         tokenizer.save_pretrained(LEGALBERT_MODEL_DIR)
+         print(f"Fine-tuned LegalBERT saved to {LEGALBERT_MODEL_DIR}")
+
+     return tokenizer, model
+
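+ # All three models are loaded (or fine-tuned on first run) once at startup,
+ # so the very first launch can take a long time.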
+ # Load models
+ pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
+ bert_tokenizer, bert_model = load_or_finetune_bert()
+ legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
+
+ def extract_text_from_pdf(file_path):
      text = ""
+     with pdfplumber.open(file_path) as pdf:
+         for page in pdf.pages:
+             text += page.extract_text() or ""
+     return text

  def extract_text_from_image(file_path):
+     image = Image.open(file_path)
+     text = pytesseract.image_to_string(image)
+     return text

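+ # Model selection heuristic: score the document against a small legal-keyword
+ # vocabulary with TF-IDF. Legal-looking text goes to LegalBERT, longer general
+ # text (more than 50 words) to Pegasus, and short text to the BERT extractor.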
+ def choose_model(text):
+     legal_keywords = ["court", "legal", "law", "judgment", "contract", "statute", "case"]
+     tfidf = TfidfVectorizer(vocabulary=legal_keywords)
+     tfidf_matrix = tfidf.fit_transform([text.lower()])
+     score = np.sum(tfidf_matrix.toarray())
+     if score > 0.1:
+         return "legalbert"
+     elif len(text.split()) > 50:
+         return "pegasus"
+     else:
+         return "bert"
+
+ def summarize_with_pegasus(text):
+     inputs = pegasus_tokenizer(text, truncation=True, padding="longest", return_tensors="pt", max_length=512)
+     summary_ids = pegasus_model.generate(
+         inputs["input_ids"],
+         max_length=400, min_length=80, length_penalty=1.5, num_beams=4
+     )
+     return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+ def summarize_with_bert(text):
+     sentences = text.split(". ")
+     if len(sentences) < 6:  # Ensure enough for 5 sentences
+         return text
+     inputs = bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = bert_model(**inputs)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)[:, 1]  # Probability of being a summary sentence
+     key_sentence_idx = probs.argsort(descending=True)[:5]  # Top 5 sentences
+     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
+
+ def summarize_with_legalbert(text):
+     sentences = text.split(". ")
+     if len(sentences) < 6:
+         return text
+     inputs = legalbert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = legalbert_model(**inputs)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)[:, 1]
+     key_sentence_idx = probs.argsort(descending=True)[:5]
+     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])

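+ # POST /summarize: accepts a PDF or image upload, extracts its text, picks a
+ # model via choose_model(), and returns the summary plus the model used as JSON.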
  @app.route('/summarize', methods=['POST'])
  def summarize_document():
      if 'file' not in request.files:
          return jsonify({"error": "No file uploaded"}), 400

      file = request.files['file']
      filename = file.filename
+     file.seek(0, os.SEEK_END)
+     file_size = file.tell()
+     if file_size > MAX_FILE_SIZE:
+         return jsonify({"error": f"File size exceeds {MAX_FILE_SIZE // (1024 * 1024)} MB"}), 413
+     file.seek(0)
+     file_path = os.path.join(UPLOAD_FOLDER, filename)
      try:
          file.save(file_path)
+     except Exception as e:
+         return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
+
+     try:
+         if filename.lower().endswith('.pdf'):
+             text = extract_text_from_pdf(file_path)
+         elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
              text = extract_text_from_image(file_path)
          else:
+             os.remove(file_path)
+             return jsonify({"error": "Unsupported file format."}), 400
      except Exception as e:
+         os.remove(file_path)
+         return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
+
+     if not text.strip():
+         os.remove(file_path)
+         return jsonify({"error": "No text extracted"}), 400
+
+     try:
+         model = choose_model(text)
+         if model == "pegasus":
+             summary = summarize_with_pegasus(text)
+         elif model == "bert":
+             summary = summarize_with_bert(text)
+         elif model == "legalbert":
+             summary = summarize_with_legalbert(text)
+     except Exception as e:
+         os.remove(file_path)
+         return jsonify({"error": f"Summarization failed: {str(e)}"}), 500
+
+     os.remove(file_path)
+     return jsonify({"model_used": model, "summary": summary})

  if __name__ == '__main__':
+     app.run(debug=True, host='0.0.0.0', port=5000)