import os
import tempfile
import shutil
import torch
import gradio as gr
from pathlib import Path
from typing import Optional, List, Union
import gc
import time
# Docling imports
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
# LangChain imports
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.schema import Document
# Transformers imports for IBM Granite model
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
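# Note: this script appears to target a Hugging Face ZeroGPU Space (hence the
# `spaces` import): the GPU is attached only while a function decorated with
# @spaces.GPU is running, so the generation function below carries that decorator.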
# Initialize IBM Granite model and tokenizer
print("Loading Granite model and tokenizer...")
model_name = "ibm-granite/granite-3.3-8b-instruct"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load model with GPU optimizations: 8-bit quantization keeps the 8B model
# within memory limits, and non-quantized modules stay in bfloat16.
# (Passing load_in_8bit directly to from_pretrained is deprecated, so the
# flag goes through BitsAndBytesConfig instead.)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)
print("Model loaded successfully!")
# Helper function to detect document format
def get_document_format(file_path) -> Optional[InputFormat]:
    """Determine the document format based on file extension"""
    try:
        file_path = str(file_path)
        extension = os.path.splitext(file_path)[1].lower()
        format_map = {
            '.pdf': InputFormat.PDF,
            '.docx': InputFormat.DOCX,
            '.doc': InputFormat.DOCX,
            '.pptx': InputFormat.PPTX,
            '.html': InputFormat.HTML,
            '.htm': InputFormat.HTML
        }
        return format_map.get(extension)
    except Exception as e:
        print(f"Error in get_document_format: {str(e)}")
        return None
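# Example: get_document_format("report.pdf") returns InputFormat.PDF, while an
# unsupported extension such as ".txt" returns None (the caller treats None as
# an unsupported format).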
# Function to convert documents to markdown
def convert_document_to_markdown(doc_path) -> str:
    """Convert document to markdown using simplified pipeline"""
    try:
        # Convert to absolute path string
        input_path = os.path.abspath(str(doc_path))
        print(f"Converting document: {doc_path}")
        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy input file to temp directory
            temp_input = os.path.join(temp_dir, os.path.basename(input_path))
            shutil.copy2(input_path, temp_input)
            # Configure pipeline options
            pipeline_options = PdfPipelineOptions()
            pipeline_options.do_ocr = False  # Disable OCR for performance
            pipeline_options.do_table_structure = True
            # Create converter with optimized options
            converter = DocumentConverter(
                allowed_formats=[
                    InputFormat.PDF,
                    InputFormat.DOCX,
                    InputFormat.HTML,
                    InputFormat.PPTX,
                ],
                format_options={
                    InputFormat.PDF: PdfFormatOption(
                        pipeline_options=pipeline_options,
                    ),
                    InputFormat.DOCX: WordFormatOption(
                        pipeline_cls=SimplePipeline
                    )
                }
            )
            # Convert document
            print("Starting conversion...")
            conv_result = converter.convert(temp_input)
            if not conv_result or not conv_result.document:
                raise ValueError(f"Failed to convert document: {doc_path}")
            # Export to markdown
            print("Exporting to markdown...")
            md = conv_result.document.export_to_markdown()
            # Create output path next to the original input file
            output_dir = os.path.dirname(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            md_path = os.path.join(output_dir, f"{base_name}_converted.md")
            # Write markdown file
            with open(md_path, "w", encoding="utf-8") as fp:
                fp.write(md)
            return md_path
    except Exception as e:
        return f"Error converting document: {str(e)}"
# Function to generate a summary using the IBM Granite model
@spaces.GPU  # Required on ZeroGPU Spaces: attaches a GPU for the duration of this call
def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
    """Generate a summary from document chunks using the IBM Granite model

    Args:
        chunks: List of document chunks to summarize
        length_type: Either "sentences" or "paragraphs"
        length_count: Number of sentences (1-10) or paragraphs (1-3)
    """
    # Concatenate the retrieved chunks
    combined_text = " ".join([chunk.page_content for chunk in chunks])
    # Construct length instruction based on type and count
    if length_type == "sentences":
        length_instruction = f"Summarize the following text in {length_count} sentence{'s' if length_count > 1 else ''}."
    else:  # paragraphs
        length_instruction = f"Summarize the following text in {length_count} paragraph{'s' if length_count > 1 else ''}."
    # Construct the prompt
    prompt = f"""<instruction>
Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. {length_instruction} Your response should only include the answer. Do not provide any further explanation.
</instruction>
<text>
{combined_text}
</text>
"""
    # Calculate an appropriate max_new_tokens budget for the requested length,
    # allowing roughly 20 tokens per sentence and 100 tokens per paragraph
    if length_type == "sentences":
        max_tokens = length_count * 20
    else:  # paragraphs
        max_tokens = length_count * 100
    # Add a buffer, then clamp to a sane range
    max_tokens = max(100, min(1000, max_tokens + 50))
    # Generate the summary using the IBM Granite model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    # Decode only the newly generated tokens (everything after the prompt)
    prompt_length = inputs["input_ids"].shape[1]
    summary = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return summary.strip()
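# Worked example of the token budget above: a 3-sentence summary yields
# 3 * 20 + 50 = 110 tokens, within the [100, 1000] clamp; a 3-paragraph
# summary yields 3 * 100 + 50 = 350 tokens.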
# Function to build the vector store from document chunks
def process_document_chunks(texts, batch_size=8):
    """Embed document chunks in batches and index them in a FAISS vector store"""
    # Create embeddings; batch_size is forwarded to the underlying encoder.
    # (Created outside the try block so the fallback below can reuse it.)
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={'trust_remote_code': True},
        encode_kwargs={'batch_size': batch_size}
    )
    try:
        # Create vector store with cosine distance for better retrieval
        return FAISS.from_documents(
            texts,
            embeddings,
            distance_strategy=DistanceStrategy.COSINE
        )
    except Exception as e:
        print(f"Error in document processing: {str(e)}")
        # Fall back to the default distance strategy if cosine indexing fails
        return FAISS.from_documents(texts, embeddings)
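# Example usage (hypothetical inputs): given chunks produced by a text splitter,
#   vectorstore = process_document_chunks(texts, batch_size=16)
#   retriever = vectorstore.as_retriever(search_kwargs={"k": 4})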
# Main function to process a document and generate a summary
def process_document(
    file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
    length_type: str = "sentences",
    length_count: int = 3,
    progress=gr.Progress()
):
    """Process a document file and generate a summary"""
    try:
        # Process input file
        if not file_obj:
            return "Please provide a file to summarize."
        document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
        # Validate document format
        format_type = get_document_format(document_path)
        if not format_type:
            return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."
        # Convert document to markdown
        progress(0.3, "Converting document to markdown...")
        markdown_path = convert_document_to_markdown(document_path)
        if markdown_path.startswith("Error"):
            return markdown_path
        # Load and split the document
        progress(0.4, "Loading and splitting document...")
        loader = UnstructuredMarkdownLoader(str(markdown_path))
        documents = loader.load()
        # Split into chunks large enough to preserve context, preferring
        # paragraph and sentence boundaries
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            length_function=len,
            separators=["\n\n", "\n", ".", " ", ""]
        )
        texts = text_splitter.split_documents(documents)
        if not texts:
            return "No text could be extracted from the document."
        # Create vector store with efficient processing
        progress(0.6, "Processing document content...")
        vectorstore = process_document_chunks(texts)
        # Collect every chunk from the docstore in small batches to limit
        # peak memory usage
        progress(0.8, "Generating summary...")
        all_chunks = []
        batch_size = 4  # Smaller batch size for memory efficiency
        doc_ids = list(vectorstore.index_to_docstore_id.values())
        for i in range(0, len(doc_ids), batch_size):
            batch_ids = doc_ids[i:i+batch_size]
            batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
            all_chunks.extend(batch_chunks)
            # Force garbage collection and pause briefly to free memory
            gc.collect()
            time.sleep(0.1)
        # Generate summary from chunks
        if len(all_chunks) > 8:
            # Many chunks: summarize each batch, then summarize the summaries
            summaries = []
            for i in range(0, len(all_chunks), batch_size):
                batch = all_chunks[i:i+batch_size]
                summary = generate_summary(
                    batch,
                    length_type=length_type,
                    length_count=max(1, length_count // 2)  # Shorter partial summaries
                )
                summaries.append(summary)
                # Force garbage collection
                gc.collect()
            # Create the final summary from the batch summaries
            final_summary = generate_summary(
                [Document(page_content=s) for s in summaries],
                length_type=length_type,
                length_count=length_count
            )
            return final_summary
        else:
            # Few chunks: generate the summary directly
            return generate_summary(
                all_chunks,
                length_type=length_type,
                length_count=length_count
            )
    except Exception as e:
        return f"Error processing document: {str(e)}"
# Create Gradio interface
def create_gradio_interface():
    """Create and launch the Gradio interface"""
    with gr.Blocks(title="Granite Document Summarization") as app:
        gr.Markdown("# Granite Document Summarization")
        gr.Markdown("Upload a document to generate a summary.")
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document (PDF, DOCX, PPTX, HTML)",
                    file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
                )
                with gr.Row():
                    length_type = gr.Radio(
                        choices=["Sentences", "Paragraphs"],
                        value="Sentences",
                        label="Summary Length Type"
                    )
                with gr.Row():
                    # Slider for sentence count (1-10)
                    sentence_count = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Sentences",
                        visible=True
                    )
                    # Radio for paragraph count (1-3)
                    paragraph_count = gr.Radio(
                        choices=["1", "2", "3"],
                        value="1",
                        label="Number of Paragraphs",
                        visible=False
                    )
                submit_btn = gr.Button("Summarize", variant="primary")
            with gr.Column(scale=2):
                output = gr.TextArea(
                    label="Summary",
                    lines=15,
                    max_lines=30
                )

        # Show the count selector that matches the chosen length type.
        # gr.update is required here: returning raw booleans would set the
        # components' values rather than toggle their visibility.
        def update_count_visibility(length_type):
            return (
                gr.update(visible=length_type == "Sentences"),
                gr.update(visible=length_type == "Paragraphs")
            )

        length_type.change(
            fn=update_count_visibility,
            inputs=[length_type],
            outputs=[sentence_count, paragraph_count]
        )

        # Pick the right count widget, lowercase the capitalized length type,
        # and convert the paragraph radio's string value to an int
        def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
            length_type_lower = length_type.lower()
            if length_type_lower == "sentences":
                return process_document(file, length_type_lower, int(sentence_count))
            else:
                return process_document(file, length_type_lower, int(paragraph_count))

        submit_btn.click(
            fn=process_document_wrapper,
            inputs=[file_input, length_type, sentence_count, paragraph_count],
            outputs=output
        )
gr.Markdown(""" | |
## How to use: | |
1. Upload a document (PDF, DOCX, PPTX, HTML) | |
2. Choose your summary length preference: | |
- Number of Sentences (1-10) | |
- Number of Paragraphs (1-3) | |
3. Click "Summarize" to process the document | |
*This application uses the IBM Granite 3.3-8b model to generate summaries.* | |
""") | |
return app | |
# Launch the application
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()