# Document Summarizer — a Hugging Face Space running on ZeroGPU
import os
import tempfile
import shutil
import torch
import gradio as gr
from typing import Optional, List, Union
import requests
from urllib.parse import urlparse

# Docling imports
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline

# LangChain imports
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

# Transformers imports for the IBM Granite model
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# Initialize the IBM Granite model and tokenizer
print("Loading Granite model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.2-8b-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.2-8b-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model loaded successfully!")
# Helper function to detect document format
def get_document_format(file_path) -> Optional[InputFormat]:
    """Determine the document format based on file extension."""
    try:
        file_path = str(file_path)
        extension = os.path.splitext(file_path)[1].lower()
        format_map = {
            '.pdf': InputFormat.PDF,
            '.docx': InputFormat.DOCX,
            '.doc': InputFormat.DOCX,
            '.pptx': InputFormat.PPTX,
            '.html': InputFormat.HTML,
            '.htm': InputFormat.HTML,
        }
        return format_map.get(extension)
    except Exception as e:
        print(f"Error in get_document_format: {e}")
        return None
# Function to convert documents to markdown
def convert_document_to_markdown(doc_path) -> str:
    """Convert a document to markdown using a simplified Docling pipeline."""
    try:
        # Convert to an absolute path string
        input_path = os.path.abspath(str(doc_path))
        print(f"Converting document: {doc_path}")

        # Create a temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy the input file into the temp directory
            temp_input = os.path.join(temp_dir, os.path.basename(input_path))
            shutil.copy2(input_path, temp_input)

            # Configure pipeline options (OCR disabled for speed)
            pipeline_options = PdfPipelineOptions()
            pipeline_options.do_ocr = False
            pipeline_options.do_table_structure = True

            # Create a converter with minimal options
            converter = DocumentConverter(
                allowed_formats=[
                    InputFormat.PDF,
                    InputFormat.DOCX,
                    InputFormat.HTML,
                    InputFormat.PPTX,
                ],
                format_options={
                    InputFormat.PDF: PdfFormatOption(
                        pipeline_options=pipeline_options,
                    ),
                    InputFormat.DOCX: WordFormatOption(
                        pipeline_cls=SimplePipeline
                    ),
                },
            )

            # Convert the document
            print("Starting conversion...")
            conv_result = converter.convert(temp_input)
            if not conv_result or not conv_result.document:
                raise ValueError(f"Failed to convert document: {doc_path}")

            # Export to markdown
            print("Exporting to markdown...")
            md = conv_result.document.export_to_markdown()

            # Write the markdown next to the original input file
            output_dir = os.path.dirname(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            md_path = os.path.join(output_dir, f"{base_name}_converted.md")
            print(f"Writing markdown to: {md_path}")
            with open(md_path, "w", encoding="utf-8") as fp:
                fp.write(md)
            return md_path
    except Exception as e:
        return f"Error converting document: {str(e)}"
# Function to download a file from a URL
def download_file_from_url(url: str) -> Optional[str]:
    """Download a file from a URL and save it to a temporary location."""
    try:
        # Parse the URL to get a filename
        parsed_url = urlparse(url)
        filename = os.path.basename(parsed_url.path)
        if not filename:
            filename = "downloaded_document"

        # Fetch the file, then add an extension based on Content-Type if needed
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()
        content_type = response.headers.get('Content-Type', '')
        if 'pdf' in content_type:
            if not filename.lower().endswith('.pdf'):
                filename += ".pdf"
        elif 'word' in content_type or 'docx' in content_type:
            if not filename.lower().endswith(('.doc', '.docx')):
                filename += ".docx"
        elif 'powerpoint' in content_type or 'pptx' in content_type:
            if not filename.lower().endswith(('.ppt', '.pptx')):
                filename += ".pptx"
        elif 'html' in content_type:
            if not filename.lower().endswith(('.html', '.htm')):
                filename += ".html"

        # Save the file into the system temp directory
        temp_dir = tempfile.gettempdir()
        file_path = os.path.join(temp_dir, filename)
        with open(file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return file_path
    except Exception as e:
        print(f"Error downloading file: {e}")
        return None
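# Example (hypothetical URL):
#   path = download_file_from_url("https://example.com/paper.pdf")
#   -> e.g. "/tmp/paper.pdf" on success, or None on failure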
# Function to generate a summary using the IBM Granite model
@spaces.GPU  # Request ZeroGPU hardware for the duration of this call (assumed placement; see note above)
def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abstractive", detail_level="medium", length="medium"):
    """Generate a summary from document chunks using the IBM Granite model."""
    # Concatenate the retrieved chunks
    combined_text = " ".join(chunk.page_content for chunk in chunks)

    # Build prompt instructions from the summary parameters
    if summary_type == "extractive":
        summary_instruction = "Extract the key sentences from the text to create a summary."
    else:  # abstractive
        summary_instruction = "Generate a comprehensive summary in your own words."
    if detail_level == "high":
        detail_instruction = "Include specific details and examples."
    elif detail_level == "medium":
        detail_instruction = "Balance key points with some supporting details."
    else:  # low
        detail_instruction = "Focus only on the main points and key takeaways."
    if length == "short":
        length_instruction = "Keep the summary concise and brief."
    elif length == "medium":
        length_instruction = "Create a moderate-length summary."
    else:  # long
        length_instruction = "Provide a comprehensive, detailed summary."

    # Construct the full prompt
    prompt = f"""<instruction>
You are a document summarization assistant. Based on the following text, {summary_instruction} {detail_instruction} {length_instruction}
</instruction>
<text>
{combined_text}
</text>
"""

    # Generate the summary
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    # Decode, then strip the echoed prompt so only the generated text remains
    summary = tokenizer.decode(output[0], skip_special_tokens=True)
    prompt_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
    return summary[len(prompt_text):].strip()
# Function to summarize a full document
def summarize_full_document(retriever, model, tokenizer, summary_params, chunk_size=8):
    """Summarize an entire document by processing all of its chunks."""
    # Collect every document stored in the vector store, in batches
    all_chunks = []
    doc_ids = list(retriever.vectorstore.index_to_docstore_id.values())
    for i in range(0, len(doc_ids), chunk_size):
        batch_ids = doc_ids[i:i + chunk_size]
        batch_chunks = [retriever.vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
        all_chunks.extend(batch_chunks)

    # Summarize the chunks in manageable batches
    summaries = []
    for i in range(0, len(all_chunks), chunk_size):
        batch = all_chunks[i:i + chunk_size]
        summary = generate_summary(
            batch,
            model,
            tokenizer,
            summary_type=summary_params.get("summary_type", "abstractive"),
            detail_level=summary_params.get("detail_level", "medium"),
            length=summary_params.get("length", "medium"),
        )
        summaries.append(summary)

    # If there were multiple batches, create a final summary of the batch summaries
    if len(summaries) > 1:
        return generate_summary(
            [Document(page_content=s) for s in summaries],
            model,
            tokenizer,
            summary_type=summary_params.get("summary_type", "abstractive"),
            detail_level=summary_params.get("detail_level", "medium"),
            length=summary_params.get("length", "medium"),
        )
    return summaries[0] if summaries else "No content to summarize"
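# The function above is a simple map-reduce summarization: each batch of chunks
# is summarized independently (map), and if more than one batch summary results,
# those summaries are summarized once more (reduce) to produce the final output.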
# Main function to process a document and generate a summary
def process_document(
    file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
    url: Optional[str] = None,
    summary_type: str = "abstractive",
    detail_level: str = "medium",
    length: str = "medium",
    progress=gr.Progress(),
):
    """Process a document file or URL and generate a summary."""
    try:
        # Resolve the input source (file or URL)
        document_path = None
        if file_obj:
            document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
        elif url and url.strip():
            progress(0.2, "Downloading document from URL...")
            document_path = download_file_from_url(url.strip())
            if not document_path:
                return "Failed to download document from URL. Please check the URL and try again."
        else:
            return "Please provide either a file or a URL to summarize."

        # Convert the document to markdown
        progress(0.3, "Converting document to markdown...")
        markdown_path = convert_document_to_markdown(document_path)
        if markdown_path.startswith("Error"):
            return markdown_path

        # Load and split the document
        progress(0.4, "Loading and splitting document...")
        loader = UnstructuredMarkdownLoader(str(markdown_path))
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            length_function=len,
        )
        texts = text_splitter.split_documents(documents)
        if not texts:
            return "No text could be extracted from the document."

        # Create embeddings and the vector store
        progress(0.6, "Creating document embeddings...")
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True},
        )
        vectorstore = FAISS.from_documents(texts, embeddings)

        # Create a retriever over the vector store
        retriever = vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 4},
        )

        # Generate the summary
        progress(0.8, "Generating summary...")
        summary_params = {
            "summary_type": summary_type,
            "detail_level": detail_level,
            "length": length,
        }
        summary = summarize_full_document(retriever, model, tokenizer, summary_params)
        progress(1.0, "Summary complete!")
        return summary
    except Exception as e:
        return f"Error processing document: {str(e)}"
# Create the Gradio interface
def create_gradio_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="Document Summarizer") as app:
        gr.Markdown("# Document Summarizer")
        gr.Markdown("Upload a document or provide a URL to generate a summary.")
        with gr.Row():
            with gr.Column():
                file_input = gr.File(label="Upload Document (PDF, DOCX, PPTX, HTML)")
                url_input = gr.Textbox(label="Or enter document URL")
                with gr.Row():
                    with gr.Column():
                        summary_type = gr.Radio(
                            choices=["extractive", "abstractive"],
                            value="abstractive",
                            label="Summary Type",
                        )
                with gr.Row():
                    with gr.Column():
                        detail_level = gr.Radio(
                            choices=["low", "medium", "high"],
                            value="medium",
                            label="Level of Detail",
                        )
                    with gr.Column():
                        length = gr.Radio(
                            choices=["short", "medium", "long"],
                            value="medium",
                            label="Summary Length",
                        )
                submit_btn = gr.Button("Generate Summary", variant="primary")
            with gr.Column():
                output = gr.Textbox(
                    label="Summary Result",
                    lines=15,
                    max_lines=30,
                )
        submit_btn.click(
            fn=process_document,
            inputs=[file_input, url_input, summary_type, detail_level, length],
            outputs=output,
        )
        gr.Markdown("""
        ## How to use:
        1. Upload a document (PDF, DOCX, PPTX, HTML) or provide a URL
        2. Choose your preferred summary parameters:
           - Summary Type: Extractive (pulls key sentences) or Abstractive (generates new text)
           - Level of Detail: Low, Medium, or High
           - Summary Length: Short, Medium, or Long
        3. Click "Generate Summary" to process the document
        """)
    return app
# Launch the application
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()
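# A plausible requirements.txt for this Space (an assumption; the listing does
# not include a dependency file, and versions are left unpinned):
#
#   torch
#   gradio
#   spaces
#   docling
#   langchain
#   langchain-community
#   langchain-huggingface
#   transformers
#   accelerate
#   sentence-transformers
#   faiss-cpu
#   requests
#   unstructured
#   markdown
#   einops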