import os
import io
import re
from flask import Flask, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
import nltk
import string
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.probability import FreqDist
from heapq import nlargest
from collections import defaultdict
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Set NLTK data path to a directory included in the project
nltk_data_dir = os.path.join(os.getcwd(), 'nltk_data')
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)
# Ensure NLTK data is available (pre-downloaded)
try:
    stopwords.words('english')  # Test if stopwords are accessible
except LookupError:
    print("NLTK data not found. Please ensure 'punkt' and 'stopwords' are pre-downloaded in 'nltk_data'.")
    # Fallback will be used if this fails
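
# If the data is not bundled with the project, one way to fetch it into nltk_data_dir
# is sketched below (kept commented out so startup stays offline-safe):
#   nltk.download('punkt', download_dir=nltk_data_dir)
#   nltk.download('stopwords', download_dir=nltk_data_dir)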

# Allowed file extensions
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename):
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route("/")
def index():
    return "Document Summarizer API is running! Use /summarize endpoint for POST requests."

@app.route("/summarize", methods=["POST"])
def summarize():
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": "Unsupported file format"}), 400

    filename = secure_filename(file.filename)
    file_content = file.read()

    # Process file based on type
    text = None
    file_ext = filename.rsplit(".", 1)[1].lower()
    try:
        if file_ext == "pdf":
            text = extract_text_from_pdf(file_content)
        elif file_ext == "docx":
            text = extract_text_from_docx(file_content)
        elif file_ext == "pptx":
            text = extract_text_from_pptx(file_content)
        elif file_ext == "txt":
            text = extract_text_from_txt(file_content)

        # Generate a summary of the text
        try:
            summary = generate_summary(text)
        except LookupError as e:
            print(f"NLTK summarization failed: {e}. Using fallback.")
            summary = simple_summarize(text)
        except Exception as e:
            print(f"Summarization error: {e}")
            summary = text[:1000] + "..." if len(text) > 1000 else text

        # Include metadata
        word_count = len(text.split())
        return jsonify({
            "filename": filename,
            "summary": summary,
            "original_word_count": word_count,
            "summary_word_count": len(summary.split()) if summary else 0
        })
    except Exception as e:
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500

# Text extraction functions
def extract_text_from_pdf(file_content):
    reader = PdfReader(io.BytesIO(file_content))
    text = ""
    for page in reader.pages:
        page_text = page.extract_text()
        if page_text:
            text += page_text + "\n\n"
    return clean_text(text)


def extract_text_from_docx(file_content):
    doc = Document(io.BytesIO(file_content))
    text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
    return clean_text(text)


def extract_text_from_pptx(file_content):
    ppt = Presentation(io.BytesIO(file_content))
    text = []
    for slide in ppt.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text") and shape.text.strip():
                text.append(shape.text)
    return clean_text("\n".join(text))


def extract_text_from_txt(file_content):
    text = file_content.decode("utf-8", errors="ignore")
    return clean_text(text)


def clean_text(text):
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s\.\,\!\?\:\;]', '', text)
    return text.strip()

def generate_summary(text, sentence_count=5):
    # Short texts are returned unchanged
    if len(text.split()) < 100:
        return text
    sentences = sent_tokenize(text)
    if len(sentences) <= sentence_count:
        return text

    clean_sentences = [s.translate(str.maketrans('', '', string.punctuation)).lower() for s in sentences]
    stop_words = set(stopwords.words('english'))

    # Count non-stopword occurrences, then normalize by the most frequent word
    word_frequencies = defaultdict(int)
    for sentence in clean_sentences:
        for word in word_tokenize(sentence):
            if word not in stop_words:
                word_frequencies[word] += 1
    max_frequency = max(word_frequencies.values()) if word_frequencies else 1
    for word in word_frequencies:
        word_frequencies[word] = word_frequencies[word] / max_frequency

    # Score each sentence by the summed frequencies of its words
    sentence_scores = defaultdict(int)
    for i, sentence in enumerate(clean_sentences):
        for word in word_tokenize(sentence):
            if word in word_frequencies:
                sentence_scores[i] += word_frequencies[word]

    # Keep the top-scoring sentences, reassembled in document order
    top_indices = nlargest(sentence_count, sentence_scores, key=sentence_scores.get)
    top_indices.sort()
    return ' '.join([sentences[i] for i in top_indices])
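
# Example use of generate_summary (a sketch, assuming the NLTK 'punkt' and
# 'stopwords' data are available): generate_summary(long_text, sentence_count=3)
# returns the three highest-scoring sentences joined in their original order.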

def simple_summarize(text, max_chars=1000):
    # Lightweight fallback used when NLTK summarization is unavailable:
    # take the leading paragraphs, then whole sentences, capped at max_chars
    paragraphs = text.split('\n\n')
    base_summary = ' '.join(paragraphs[:3])
    if len(text) <= max_chars:
        return text
    if len(base_summary) < max_chars:
        remaining_text = ' '.join(paragraphs[3:])
        sentences = re.split(r'(?<=[.!?])\s+', remaining_text)
        for sentence in sentences:
            if len(base_summary) + len(sentence) + 1 <= max_chars:
                base_summary += ' ' + sentence
            else:
                break
    if len(base_summary) > max_chars:
        base_summary = base_summary[:max_chars] + "..."
    return base_summary

if __name__ == "__main__":
    # For local testing only
    app.run(host="0.0.0.0", port=7860)
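
# Example request against a local run (a sketch; adjust host/port as needed):
#   curl -X POST -F "file=@report.pdf" http://localhost:7860/summarize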