Spaces:
Sleeping
Sleeping
File size: 3,896 Bytes
764d4f7 2a3fae3 80d0b8a b7db40a 764d4f7 2a3fae3 764d4f7 2a3fae3 a911ba5 524f780 a911ba5 80d0b8a 764d4f7 b7db40a a911ba5 d2d0219 2a3fae3 d2d0219 2a3fae3 d2d0219 a911ba5 92d0377 98e82be 80d0b8a a911ba5 98e82be 3b4df89 2a3fae3 80d0b8a 764d4f7 80d0b8a d2d0219 764d4f7 80d0b8a 764d4f7 2a3fae3 80d0b8a 2a3fae3 764d4f7 2a3fae3 98e82be a911ba5 98e82be a911ba5 98e82be a911ba5 98e82be a911ba5 92d0377 98e82be 80d0b8a 98e82be 764d4f7 92d0377 2a3fae3 92d0377 a911ba5 d2d0219 92d0377 2a3fae3 92d0377 a911ba5 d2d0219 92d0377 2a3fae3 98e82be 92d0377 98e82be a911ba5 b7db40a 92d0377 a911ba5 53425a8 9fd7d89 92d0377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import io
import logging
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Root logging at INFO so the request/processing messages below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# The T5 checkpoint is loaded once, when this module is imported, and the
# resulting tokenizer/model objects are reused by every summarization call.
MODEL_NAME = "t5-base"
logger.info("Loading T5-Base model...")
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
logger.info("T5-Base model loaded successfully.")
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename):
    """Return True when *filename* ends in a supported extension.

    The extension is the text after the last ".", compared case-insensitively
    against ALLOWED_EXTENSIONS. Names without any "." are rejected.
    """
    if "." not in filename:
        return False
    extension = filename.rpartition(".")[2]
    return extension.lower() in ALLOWED_EXTENSIONS
def summarize_text(text, max_length=150, min_length=30):
    """Summarize *text* using the module-level T5-Base tokenizer and model.

    The input is prefixed with T5's "summarize: " task prefix and truncated
    to 512 tokens, so only the beginning of a long document contributes to
    the summary.

    Args:
        text: Plain text to summarize.
        max_length: Maximum length, in tokens, of the generated summary.
        min_length: Minimum length, in tokens, of the generated summary.

    Returns:
        The decoded summary string with special tokens removed.

    Raises:
        Any exception from tokenization or generation is logged (with its
        traceback) and re-raised unchanged.
    """
    try:
        # T5 is a multi-task model; the "summarize: " prefix selects the task.
        inputs = tokenizer(
            "summarize: " + text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        summary_ids = model.generate(
            inputs["input_ids"],
            # Pass the mask the tokenizer already built instead of letting
            # generate() guess it (which also emits a warning).
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in T5 summarization: %s", e)
        raise
@app.route("/", methods=["GET"])
def index():
    """Return a plain-text status message pointing callers at /summarize."""
    logger.info("Root endpoint accessed.")
    status_message = (
        "Document Summarizer API with T5-Base is running! "
        "Use /summarize endpoint for POST requests."
    )
    return status_message
@app.route("/summarize", methods=["POST"])
def summarize():
    """Extract text from an uploaded document and return a T5 summary.

    Expects a multipart POST with the document in the "file" field; the
    extension must be one of ALLOWED_EXTENSIONS.

    Returns:
        200 with {"filename", "summary"} on success;
        400 with {"error"} when no file is sent, the format is unsupported,
        or no text could be extracted;
        500 with {"error"} when extraction or summarization raises.
    """
    logger.info("Summarize endpoint called.")
    if "file" not in request.files:
        logger.error("No file uploaded.")
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files["file"]
    # Treat a missing (None) filename the same as an empty one.
    if not file.filename:
        logger.error("No file selected.")
        return jsonify({"error": "No selected file"}), 400
    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": "Unsupported file format"}), 400

    # Take the extension from the original, already-validated name.
    # secure_filename() strips non-ASCII characters and leading dots, so a name
    # like "报告.pdf" becomes "pdf" with no extension left; splitting that used
    # to raise IndexError outside the try block below.
    file_ext = file.filename.rsplit(".", 1)[1].lower()
    filename = secure_filename(file.filename)
    file_content = file.read()

    # Extension -> text extractor; keys mirror ALLOWED_EXTENSIONS.
    extractors = {
        "pdf": summarize_pdf,
        "docx": summarize_docx,
        "pptx": summarize_pptx,
        "txt": summarize_txt,
    }
    try:
        text = extractors[file_ext](file_content)
        # An empty input (e.g. a scanned PDF with no text layer) would make the
        # model summarize nothing, so reject it as a client error instead.
        if not text or not text.strip():
            logger.error(f"No extractable text in file {filename}.")
            return jsonify({"error": "No text could be extracted from the file"}), 400
        summary = summarize_text(text)
    except Exception as e:
        logger.exception(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500

    logger.info(f"File {filename} summarized successfully with T5.")
    return jsonify({"filename": filename, "summary": summary})
def summarize_pdf(file_content):
    """Return the text of a PDF supplied as raw bytes.

    Pages whose extracted text is empty or None are skipped; the remaining
    page texts are joined with newlines. Despite the name, this only extracts
    text — summarization happens in summarize_text().
    """
    reader = PdfReader(io.BytesIO(file_content))
    # extract_text() parses the page's content stream, so call it once per
    # page (the original comprehension called it twice for every page).
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
def summarize_docx(file_content):
    """Return the text of a .docx document supplied as raw bytes.

    Each entry of Document.paragraphs becomes one line of the result (empty
    paragraphs yield empty lines). Despite the name, this only extracts text.
    NOTE(review): Document.paragraphs does not appear to cover table-cell
    text — confirm whether tables need to be included.
    """
    stream = io.BytesIO(file_content)
    paragraphs = Document(stream).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
def summarize_pptx(file_content):
    """Return the text of a .pptx presentation supplied as raw bytes.

    Walks every slide in order and, within each slide, every shape that
    exposes a ``text`` attribute; each such shape's text becomes one line.
    Despite the name, this only extracts text.
    """
    presentation = Presentation(io.BytesIO(file_content))
    return "\n".join(
        shape.text
        for slide in presentation.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )
def summarize_txt(file_content):
    """Decode an uploaded plain-text file from UTF-8 bytes to str.

    Uses "utf-8-sig" so that a leading UTF-8 byte-order mark (common in files
    saved by Windows editors) is stripped rather than kept as U+FEFF and fed
    to the tokenizer. Input without a BOM decodes exactly as plain UTF-8.

    Raises:
        UnicodeDecodeError: if the bytes are not valid UTF-8.
    """
    return file_content.decode("utf-8-sig")
if __name__ == "__main__":
    # Never enable the Werkzeug debugger by default: bound to 0.0.0.0 it exposes
    # an interactive console that allows arbitrary code execution. Opt in for
    # local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    # 7860 remains the default port; PORT lets a deployment override it.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=debug_enabled)