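# Hugging Face file-viewer header (page chrome captured with the file, not part of app.py):
#   author: mike23415 | commit: a911ba5 (verified) "Update app.py" | size: 3.9 kB
#   viewer links: raw, history, blame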
import os
import io
import logging
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Configure logging
# The root logger is set to INFO; `logger` is this module's named logger and is
# used by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
# Load T5 model and tokenizer (done once at startup)
# NOTE: this runs at import time, so importing the module blocks until the
# "t5-base" tokenizer and weights are available (presumably downloaded on
# first use -- confirm against the deployment's cache setup).
logger.info("Loading T5-Base model...")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
logger.info("T5-Base model loaded successfully.")
# Lower-case file extensions (without the dot) that allowed_file() accepts for
# uploads to the /summarize route.
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}
def allowed_file(filename):
    """Tell whether *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    The extension is whatever follows the last dot, compared
    case-insensitively. A name containing no dot at all is rejected.
    """
    if "." not in filename:
        return False
    extension = filename.rsplit(".", 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
def summarize_text(text, max_length=150, min_length=30):
    """Generate a summary of *text* with the module-level T5 model.

    Args:
        text: Text to summarize. The input is tokenized with truncation at
            512 tokens, so anything past that limit is not seen by the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        min_length: Forwarded to ``model.generate`` as ``min_length``.

    Returns:
        The first generated sequence, decoded with special tokens removed.

    Raises:
        Exception: Any error from tokenization, generation or decoding is
            logged and then re-raised unchanged.
    """
    generation_options = {
        "max_length": max_length,
        "min_length": min_length,
        "length_penalty": 2.0,
        "num_beams": 4,
        "early_stopping": True,
    }
    try:
        # T5 picks its task from a textual prefix on the input.
        encoded = tokenizer(
            "summarize: " + text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        generated = model.generate(encoded["input_ids"], **generation_options)
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        logger.error(f"Error in T5 summarization: {str(exc)}")
        raise
@app.route("/", methods=["GET"])
def index():
    """Landing route: log the access and answer with a plain-text status banner."""
    banner = "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
    logger.info("Root endpoint accessed.")
    return banner
@app.route("/summarize", methods=["POST"])
def summarize():
    """Summarize an uploaded document.

    Expects a multipart upload under the ``file`` form field whose name ends
    in one of the extensions accepted by ``allowed_file``.

    Returns:
        On success, a JSON body ``{"filename": ..., "summary": ...}``.
        A JSON ``{"error": ...}`` body with status 400 when the upload is
        missing, unnamed, of an unsupported type, or yields no text.
        A JSON ``{"error": ...}`` body with status 500 when text extraction
        or summarization raises.
    """
    logger.info("Summarize endpoint called.")

    if "file" not in request.files:
        logger.error("No file uploaded.")
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    if file.filename == "":
        logger.error("No file selected.")
        return jsonify({"error": "No selected file"}), 400

    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": "Unsupported file format"}), 400

    # Take the extension from the name that was just validated, not from the
    # sanitized one: secure_filename() drops non-ASCII characters and strips
    # leading dots (".pdf" or "文档.pdf" both become "pdf"), which leaves no
    # dot to split on and used to raise IndexError here.
    file_ext = file.filename.rsplit(".", 1)[1].lower()
    # The sanitized name is only used for logging and for the response body.
    filename = secure_filename(file.filename)

    extractors = {
        "pdf": summarize_pdf,
        "docx": summarize_docx,
        "pptx": summarize_pptx,
        "txt": summarize_txt,
    }
    extractor = extractors.get(file_ext)
    if extractor is None:
        # An extension was allowed but has no extractor; reject it explicitly
        # instead of failing later on an unbound variable.
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": "Unsupported file format"}), 400

    file_content = file.read()

    try:
        text = extractor(file_content)
        if not text or not text.strip():
            # Without any input text the model would still generate output,
            # which would be meaningless; report the problem instead.
            logger.error(f"No extractable text in file {filename}.")
            return jsonify({"error": "No text could be extracted from the file"}), 400
        summary = summarize_text(text)
        logger.info(f"File {filename} summarized successfully with T5.")
        return jsonify({"filename": filename, "summary": summary})
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500
def summarize_pdf(file_content):
    """Extract the text of a PDF supplied as raw bytes.

    Args:
        file_content: The complete PDF file as bytes.

    Returns:
        The text of every page that yielded any text, joined with newlines.
        Pages with no extractable text are skipped, so the result is an empty
        string when no page produces text.
    """
    reader = PdfReader(io.BytesIO(file_content))
    # Run the extraction once per page; it was previously called twice
    # (once to filter, once for the value), doubling the work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
def summarize_docx(file_content):
    """Return the text of a .docx file given as raw bytes.

    The text of each entry in the document's ``paragraphs`` collection is
    joined with newlines, in document order.
    """
    stream = io.BytesIO(file_content)
    paragraphs = Document(stream).paragraphs
    return "\n".join([paragraph.text for paragraph in paragraphs])
def summarize_pptx(file_content):
    """Return the text of a .pptx file given as raw bytes.

    Walks every shape of every slide in order and collects the ``text`` of
    each shape that exposes such an attribute, joined with newlines.
    """
    deck = Presentation(io.BytesIO(file_content))
    return "\n".join(
        shape.text
        for slide in deck.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )
def summarize_txt(file_content):
    """Decode the bytes of a plain-text upload as UTF-8.

    A leading UTF-8 byte-order mark is dropped, so it does not survive as a
    stray U+FEFF character at the start of the text. Input without a BOM
    decodes exactly as plain UTF-8.

    Args:
        file_content: The raw bytes of the uploaded file.

    Returns:
        The decoded text.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    return file_content.decode("utf-8-sig")
if __name__ == "__main__":
    # The server binds to all interfaces. Debug mode enables the Werkzeug
    # interactive debugger, which lets anyone who can reach the port execute
    # arbitrary Python, so it is off unless explicitly requested with
    # FLASK_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)