"""Flask API that summarizes uploaded documents with the T5-Base model."""

import os
import io
import logging

# The Hugging Face libraries resolve their cache location from HF_HOME when
# they are imported, so the variable has to be set before `transformers` is
# imported below; assigning it afterwards does not move the cache.
os.environ["HF_HOME"] = "/app/hf_cache"

from flask import Flask, request, jsonify  # noqa: E402
from werkzeug.utils import secure_filename  # noqa: E402
from PyPDF2 import PdfReader  # noqa: E402
from docx import Document  # noqa: E402
from pptx import Presentation  # noqa: E402
from transformers import T5Tokenizer, T5ForConditionalGeneration  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Load T5 model and tokenizer once at import time so every request reuses them.
logger.info("Loading T5-Base model...")
try:
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    logger.info("T5-Base model loaded successfully.")
except Exception:
    # Without the model the service cannot do anything useful: log the full
    # traceback and abort startup by re-raising.
    logger.exception("Failed to load T5-Base")
    raise

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename: str) -> bool:
    """Check if the uploaded file has an allowed extension.

    Args:
        filename: Name of the uploaded file as sent by the client.

    Returns:
        True when the text after the last dot, lower-cased, is one of
        ALLOWED_EXTENSIONS; False otherwise (including a missing or empty
        name, or a name without a dot).
    """
    if not filename or "." not in filename:
        return False
    return filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


def summarize_text(text: str, max_length: int = 150, min_length: int = 30) -> str:
    """Summarize text using T5-Base.

    Args:
        text: Plain text to summarize.
        max_length: Upper bound passed to ``model.generate`` as ``max_length``.
        min_length: Lower bound passed to ``model.generate`` as ``min_length``.

    Returns:
        The decoded summary. Returns a fixed notice when ``text`` is blank,
        and the string "Error summarizing text." when tokenization or
        generation raises (the exception is logged, not propagated).
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."

        # The "summarize: " prefix is prepended to the input; the tokenizer
        # truncates the prompt to 512 tokens, so longer documents are
        # summarized from their beginning only.
        input_text = "summarize: " + text
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        summary_ids = model.generate(
            inputs["input_ids"],
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception:
        # Best-effort contract kept from the original: callers get a
        # placeholder string instead of an exception. logger.exception
        # records the traceback so the failure is still diagnosable.
        logger.exception("Error in T5 summarization")
        return "Error summarizing text."


@app.route("/", methods=["GET"])
def index():
    """Root endpoint: return a plain-text notice that the API is running."""
    logger.info("Root endpoint accessed.")
    return "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."

@app.route("/summarize", methods=["POST"])
def summarize():
    """Handle file uploads and summarization.

    Expects a multipart POST with the document in the ``file`` field.

    Returns:
        A JSON response:
        * 200 with ``{"filename", "summary"}`` on success;
        * 400 with ``{"error"}`` when no file is sent, the name is empty,
          the extension is not allowed, or no text could be extracted;
        * 500 with ``{"error"}`` when extraction raises.
    """
    logger.info("Summarize endpoint called.")

    if "file" not in request.files:
        logger.error("No file uploaded.")
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    # `not` also covers a filename of None, not only the empty string.
    if not file.filename:
        logger.error("No file selected.")
        return jsonify({"error": "No selected file"}), 400

    if not allowed_file(file.filename):
        logger.error("Unsupported file format: %s", file.filename)
        return jsonify({"error": "Unsupported file format"}), 400

    # Take the extension from the raw name that allowed_file() just accepted.
    # secure_filename() may drop the dot (e.g. a name whose stem is entirely
    # non-ASCII collapses to just "pdf"), so splitting the sanitized name
    # could raise IndexError.
    file_ext = file.filename.rsplit(".", 1)[1].lower()
    filename = secure_filename(file.filename)
    file_content = file.read()

    extractor = _EXTRACTORS.get(file_ext)
    if extractor is None:
        # Defensive: only reachable if ALLOWED_EXTENSIONS and _EXTRACTORS
        # ever get out of sync.
        return jsonify({"error": "Unsupported file format"}), 400

    try:
        text = extractor(file_content)
        if not text.strip():
            return jsonify({"error": "No extractable text found in the document"}), 400

        summary = summarize_text(text)
        logger.info("File %s summarized successfully.", filename)
        return jsonify({"filename": filename, "summary": summary})
    except Exception as e:
        # Request boundary: log the traceback and turn any parser failure
        # into a JSON 500 instead of an HTML error page.
        logger.exception("Error processing file %s", filename)
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500


def summarize_pdf(file_content):
    """Extract text from PDF.

    Args:
        file_content: Raw bytes of the PDF file.

    Returns:
        The text of all pages joined by newlines, stripped. Pages for which
        ``extract_text()`` returns nothing contribute an empty string.
    """
    reader = PdfReader(io.BytesIO(file_content))
    return "\n".join(page.extract_text() or "" for page in reader.pages).strip()


def summarize_docx(file_content):
    """Extract text from DOCX.

    Args:
        file_content: Raw bytes of the DOCX file.

    Returns:
        The non-blank paragraphs of the document joined by newlines, stripped.
    """
    doc = Document(io.BytesIO(file_content))
    return "\n".join(
        para.text for para in doc.paragraphs if para.text.strip()
    ).strip()


def summarize_pptx(file_content):
    """Extract text from PPTX.

    Args:
        file_content: Raw bytes of the PPTX file.

    Returns:
        The stripped text of every shape that exposes a non-blank ``text``
        attribute, in slide order, joined by newlines.
    """
    ppt = Presentation(io.BytesIO(file_content))
    chunks = [
        shape.text.strip()
        for slide in ppt.slides
        for shape in slide.shapes
        if hasattr(shape, "text") and shape.text.strip()
    ]
    return "\n".join(chunks).strip()


def summarize_txt(file_content):
    """Extract text from TXT file.

    Args:
        file_content: Raw bytes of the text file.

    Returns:
        The decoded, stripped text. "utf-8-sig" drops a leading byte-order
        mark if present, and undecodable bytes become U+FFFD instead of
        failing the whole request.
    """
    return file_content.decode("utf-8-sig", errors="replace").strip()


# Maps a lower-cased file extension to the function that extracts its text.
_EXTRACTORS = {
    "pdf": summarize_pdf,
    "docx": summarize_docx,
    "pptx": summarize_pptx,
    "txt": summarize_txt,
}


if __name__ == "__main__":
    # The server binds to every interface, so the interactive Werkzeug
    # debugger (which allows arbitrary code execution) must not be on by
    # default. Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)