|
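"""Granite Document Summarization Space.

Converts an uploaded document (PDF, DOCX, PPTX, HTML) to markdown with
Docling, chunks and embeds it into a FAISS index via LangChain, and generates
a length-controlled summary with the IBM Granite 3.3 8B Instruct model behind
a Gradio UI.
"""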
import gc
import os
import shutil
import tempfile
import time
from typing import Optional, List, Union

import gradio as gr
import torch

# Docling: document parsing and conversion to markdown.
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline

# LangChain: loading, chunking, embeddings, and the FAISS vector store.
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.schema import Document

# ZeroGPU support for Hugging Face Spaces.
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
|
print("Loading Granite model and tokenizer...") |
|
model_name = "ibm-granite/granite-3.3-8b-instruct" |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
|
|
quantization_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_compute_dtype=torch.bfloat16 |
|
) |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
device_map="auto", |
|
quantization_config=quantization_config |
|
) |
|
print("Model loaded successfully!") |
|
|
|
|
|
def get_document_format(file_path) -> Optional[InputFormat]:
    """Determine the document format based on file extension."""
    try:
        file_path = str(file_path)
        extension = os.path.splitext(file_path)[1].lower()
        format_map = {
            '.pdf': InputFormat.PDF,
            '.docx': InputFormat.DOCX,
            '.doc': InputFormat.DOCX,  # legacy .doc is treated as DOCX
            '.pptx': InputFormat.PPTX,
            '.html': InputFormat.HTML,
            '.htm': InputFormat.HTML
        }
        return format_map.get(extension)
    except Exception as e:
        print(f"Error in get_document_format: {str(e)}")
        return None
|
|
def convert_document_to_markdown(doc_path) -> str:
    """Convert a document to markdown using a simplified Docling pipeline."""
    try:
        input_path = os.path.abspath(str(doc_path))
        print(f"Converting document: {doc_path}")

        # Work on a copy inside a temporary directory so the uploaded file is
        # never modified in place.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_input = os.path.join(temp_dir, os.path.basename(input_path))
            shutil.copy2(input_path, temp_input)

            # OCR is disabled for speed; table structure recognition stays on.
            pipeline_options = PdfPipelineOptions()
            pipeline_options.do_ocr = False
            pipeline_options.do_table_structure = True

            converter = DocumentConverter(
                allowed_formats=[
                    InputFormat.PDF,
                    InputFormat.DOCX,
                    InputFormat.HTML,
                    InputFormat.PPTX,
                ],
                format_options={
                    InputFormat.PDF: PdfFormatOption(
                        pipeline_options=pipeline_options,
                    ),
                    InputFormat.DOCX: WordFormatOption(
                        pipeline_cls=SimplePipeline
                    )
                }
            )

            print("Starting conversion...")
            conv_result = converter.convert(temp_input)
            if not conv_result or not conv_result.document:
                raise ValueError(f"Failed to convert document: {doc_path}")

            print("Exporting to markdown...")
            md = conv_result.document.export_to_markdown()

            # Write the markdown next to the original document.
            output_dir = os.path.dirname(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            md_path = os.path.join(output_dir, f"{base_name}_converted.md")

            with open(md_path, "w", encoding="utf-8") as fp:
                fp.write(md)
            return md_path
    except Exception as e:
        # Callers detect failure via this "Error" prefix.
        return f"Error converting document: {str(e)}"
|
|
def clean_and_prepare_text(markdown_path):
    """Load, clean, and prepare document text for better processing."""
    try:
        loader = UnstructuredMarkdownLoader(str(markdown_path))
        documents = loader.load()

        if not documents:
            return None, "No content could be extracted from the document."

        raw_text = "\n\n".join(doc.page_content for doc in documents)

        # Split into paragraphs *before* collapsing whitespace; collapsing
        # first would erase the "\n\n" boundaries this split relies on.
        paragraphs = [p.strip() for p in raw_text.split("\n\n") if p.strip()]

        processed_docs = []
        for i, para in enumerate(paragraphs):
            # Normalize whitespace and tighten spacing around punctuation.
            text = " ".join(para.split())
            text = text.replace(" .", ".").replace(" ,", ",")
            # Ensure a space after sentence-ending punctuation, then collapse
            # any double spaces this introduces.
            for punct in ['.', '!', '?']:
                text = text.replace(punct, f"{punct} ")
            text = " ".join(text.split())

            if len(text) > 10:  # skip trivially short fragments
                processed_docs.append(Document(
                    page_content=text,
                    metadata={"source": markdown_path, "paragraph": i}
                ))

        return processed_docs, None

    except Exception as e:
        return None, f"Error processing document text: {str(e)}"
|
|
def create_optimized_text_splitter():
    """Create an optimized text splitter for document processing."""
    return RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""]
    )
|
|
def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
    """Generate a summary from document chunks using the IBM Granite model.

    Args:
        chunks: List of document chunks to summarize
        length_type: Either "sentences" or "paragraphs"
        length_count: Number of sentences (1-10) or paragraphs (1-3)
    """
    print(f"Generating summary with length_type={length_type}, length_count={length_count}")

    try:
        length_count = int(length_count)
    except (ValueError, TypeError):
        print(f"Failed to convert length_count to int: {length_count}, using default 3")
        length_count = 3

    # Clamp the requested length to the supported ranges.
    if length_type == "sentences":
        length_count = max(1, min(10, length_count))
    else:
        length_count = max(1, min(3, length_count))

    # Normalize whitespace in each chunk before joining them into one context.
    cleaned_chunks = []
    for chunk in chunks:
        text = ' '.join(chunk.page_content.split())
        cleaned_chunks.append(text)
    combined_text = " ".join(cleaned_chunks)

    if length_type == "sentences":
        length_instruction = f"Create a concise summary that is EXACTLY {length_count} complete sentences. Not {length_count-1} sentences. Not {length_count+1} sentences. EXACTLY {length_count} sentences."
    else:
        length_instruction = f"Create a concise summary that is EXACTLY {length_count} paragraphs. Each paragraph should be 2-4 sentences long. Not {length_count-1} paragraphs. Not {length_count+1} paragraphs. EXACTLY {length_count} paragraphs."

    prompt = f"""<instruction>
You are an expert document summarizer. Your task is to create a high-quality summary of the following text.

{length_instruction}

Remember:
- Your summary must capture the main points of the document
- Your summary must be in your own words (not copied text)
- Your summary must be clearly written and well-structured
- Do not include any explanations, headings, bullet points, or additional formatting
- Respond ONLY with the summary text itself

</instruction>

<text>
{combined_text}
</text>
"""

    # Budget roughly 40 new tokens per sentence or 150 per paragraph, clamped
    # to a sane range.
    if length_type == "sentences":
        max_tokens = length_count * 40
    else:
        max_tokens = length_count * 150
    max_tokens = max(100, min(1500, max_tokens))
    print(f"Using max_new_tokens={max_tokens}")

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2
        )

    # Decode only the newly generated tokens; slicing the token ids is more
    # robust than trimming the decoded string by prompt length.
    prompt_length = inputs["input_ids"].shape[1]
    summary = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

    # Post-process: the model does not always honor the exact sentence count,
    # so trim any excess sentences ourselves.
    if length_type == "sentences":
        sentences = [s.strip() for s in summary.split('.') if s.strip()]
        if len(sentences) > length_count:
            summary = '. '.join(sentences[:length_count]) + '.'
        elif len(sentences) < length_count:
            print(f"Warning: Generated only {len(sentences)} sentences instead of {length_count}")

    return summary.strip()
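# Note: Granite Instruct models also accept chat-formatted input via the
# standard transformers tokenizer.apply_chat_template API. A minimal sketch of
# that alternative (an assumption, not what this app does above):
#
#   messages = [{"role": "user", "content": prompt}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(model.device)
#   output = model.generate(input_ids, max_new_tokens=max_tokens)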
|
|
|
|
|
def process_document_chunks(texts):
    """Embed document chunks and index them in a FAISS vector store."""
    try:
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True}
        )

        # Use the DistanceStrategy enum; passing the plain string "cosine"
        # would be silently ignored and fall back to L2 distance.
        vectorstore = FAISS.from_documents(
            texts,
            embeddings,
            distance_strategy=DistanceStrategy.COSINE
        )

        return vectorstore
    except Exception as e:
        print(f"Error in document processing: {str(e)}")
        # Fall back to the default (L2) distance strategy.
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True}
        )
        return FAISS.from_documents(texts, embeddings)
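# A quick usage sketch (hypothetical query, standard LangChain API): given the
# chunks produced by the splitter above,
#
#   vectorstore = process_document_chunks(texts)
#   hits = vectorstore.similarity_search("main findings", k=4)
#
# returns the four chunks closest to the query under cosine distance.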
|
|
|
|
|
@spaces.GPU
def process_document(
    file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
    length_type: str = "sentences",
    length_count: int = 3,
    progress=gr.Progress()
):
    """Process a document file and generate a summary."""
    try:
        if not file_obj:
            return "Please provide a file to summarize."

        document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)

        format_type = get_document_format(document_path)
        if not format_type:
            return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."

        progress(0.3, "Converting document to markdown...")
        markdown_path = convert_document_to_markdown(document_path)
        if markdown_path.startswith("Error"):
            return markdown_path

        progress(0.4, "Processing document text...")
        processed_docs, error = clean_and_prepare_text(markdown_path)
        if error:
            return error

        text_splitter = create_optimized_text_splitter()
        texts = text_splitter.split_documents(processed_docs)

        if not texts:
            return "No text could be extracted from the document."

        progress(0.6, "Processing document content...")
        vectorstore = process_document_chunks(texts)

        progress(0.8, "Generating summary...")

        # Pull every chunk back out of the docstore in small batches, freeing
        # memory between batches.
        all_chunks = []
        batch_size = 4
        doc_ids = list(vectorstore.index_to_docstore_id.values())

        for i in range(0, len(doc_ids), batch_size):
            batch_ids = doc_ids[i:i+batch_size]
            batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
            all_chunks.extend(batch_chunks)
            gc.collect()
            time.sleep(0.1)  # brief pause to keep the UI responsive

        # Small documents are summarized in one pass; medium ones are capped
        # at the first 8 chunks; large ones get a map-reduce pass (summarize
        # each batch, then summarize the summaries).
        if len(all_chunks) <= 8:
            return generate_summary(
                all_chunks,
                length_type=length_type.lower(),
                length_count=length_count
            )
        elif len(all_chunks) <= 16:
            return generate_summary(
                all_chunks[:8],
                length_type=length_type.lower(),
                length_count=length_count
            )
        else:
            summaries = []
            for i in range(0, len(all_chunks), batch_size):
                batch = all_chunks[i:i+batch_size]
                summary = generate_summary(
                    batch,
                    length_type="paragraphs",
                    length_count=1
                )
                summaries.append(summary)
                gc.collect()

            final_summary = generate_summary(
                [Document(page_content=s) for s in summaries],
                length_type=length_type.lower(),
                length_count=length_count
            )
            return final_summary

    except Exception as e:
        return f"Error processing document: {str(e)}"
|
|
def create_gradio_interface():
    """Create and launch the Gradio interface."""
    with gr.Blocks(title="Granite Document Summarization") as app:
        gr.Markdown("# Granite Document Summarization")
        gr.Markdown("Upload a document to generate a summary.")

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document (PDF, DOCX, PPTX, HTML)",
                    file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
                )

                with gr.Row():
                    length_type = gr.Radio(
                        choices=["Sentences", "Paragraphs"],
                        value="Sentences",
                        label="Summary Length Type"
                    )

                with gr.Row():
                    # Only one of these two controls is visible at a time,
                    # toggled by the radio button above.
                    sentence_count = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Sentences",
                        visible=True
                    )
                    paragraph_count = gr.Radio(
                        choices=["1", "2", "3"],
                        value="1",
                        label="Number of Paragraphs",
                        visible=False
                    )

                submit_btn = gr.Button("Summarize", variant="primary")

            with gr.Column(scale=2):
                output = gr.TextArea(
                    label="Summary",
                    lines=15,
                    max_lines=30
                )

        def update_count_visibility(length_type):
            """Show the count control matching the selected length type."""
            is_sentences = length_type == "Sentences"
            return [
                gr.update(visible=is_sentences),
                gr.update(visible=not is_sentences)
            ]

        length_type.change(
            fn=update_count_visibility,
            inputs=[length_type],
            outputs=[sentence_count, paragraph_count]
        )

        def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
            """Validate the count for the selected length type, then summarize."""
            length_type_lower = length_type.lower()
            print(f"Processing with length_type={length_type}, sentence_count={sentence_count}, paragraph_count={paragraph_count}")

            if length_type_lower == "sentences":
                try:
                    count = max(1, min(10, int(sentence_count)))
                    print(f"Using sentence count: {count}")
                except (ValueError, TypeError):
                    print(f"Invalid sentence count: {sentence_count}, using default 3")
                    count = 3
            else:
                try:
                    # The paragraph radio reports its value as a string.
                    count = max(1, min(3, int(paragraph_count)))
                    print(f"Using paragraph count: {count}")
                except (ValueError, TypeError):
                    print(f"Invalid paragraph count: {paragraph_count}, using default 1")
                    count = 1

            return process_document(file, length_type_lower, count)

        submit_btn.click(
            fn=process_document_wrapper,
            inputs=[file_input, length_type, sentence_count, paragraph_count],
            outputs=output
        )

        gr.Markdown("""
## How to use:
1. Upload a document (PDF, DOCX, PPTX, HTML)
2. Choose your summary length preference:
   - Number of Sentences (1-10)
   - Number of Paragraphs (1-3)
3. Click "Summarize" to process the document

*This application uses the IBM Granite 3.3 8B Instruct model to generate summaries.*
""")

    return app
|
|
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()