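"""Multimodal RAG demo: upload a PDF, ask questions about it, and optionally
attach an image whose CLIP-derived description is folded into the query."""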
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS


class MultimodalRAG:
    def __init__(self, pdf_path=None):
        # CLIP handles the image side; a sentence-transformer embeds the text chunks.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.pdf_path = pdf_path
        self.documents = []
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        try:
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                # Generation settings belong in pipeline_kwargs, not model_kwargs.
                pipeline_kwargs={"temperature": 0.7, "max_length": 512},
            )
        except Exception as e:
            print(f"Error loading flan-t5 model: {e}")
            # Fallback requires OPENAI_API_KEY to be set in the environment.
            from langchain_community.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)
        if pdf_path and os.path.exists(pdf_path):
            self.load_pdf(pdf_path)

    def load_pdf(self, pdf_path):
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        loader = PyPDFLoader(pdf_path)
        self.documents = loader.load()
        # Split pages into overlapping chunks so each retrieved passage
        # fits comfortably in the LLM's input window.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        self.documents = text_splitter.split_documents(self.documents)
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )
        return f"Successfully loaded and processed PDF: {pdf_path}"

    def process_image(self, image_path):
        if not os.path.exists(image_path):
            print(f"Warning: Image path {image_path} does not exist")
            return None
        # CLIP's processor expects RGB input.
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.vision_model.get_image_features(**inputs)
        return image_features

    def generate_image_description(self, image_features):
        # Placeholder: CLIP embeds images but does not generate captions,
        # so every image currently maps to the same generic description.
        return "a photo"

    def _enhance_query(self, query_text, image_path=None):
        # If an image is supplied, append its text description so the
        # text-only retriever and LLM can take it into account.
        if image_path:
            image_features = self.process_image(image_path)
            if image_features is not None:
                return f"{query_text} {self.generate_image_description(image_features)}"
        return query_text

    def retrieve_related_documents(self, query_text, image_path=None):
        enhanced_query = self._enhance_query(query_text, image_path)
        return self.retriever.get_relevant_documents(enhanced_query)

    def answer_query(self, query_text, image_path=None):
        if not self.vector_store or not self.qa_chain:
            # Return an (answer, sources) pair so callers can always unpack.
            return "Please upload a PDF document first.", []
        # Run the chain on the image-enhanced query so an attached image
        # actually influences retrieval and the generated answer.
        enhanced_query = self._enhance_query(query_text, image_path)
        result = self.qa_chain({"query": enhanced_query})
        answer = result["result"]
        # Truncate each source chunk for display.
        sources = [doc.page_content[:1000] + "..." for doc in result["source_documents"]]
        return answer, sources


# Single shared instance backing the Gradio callbacks.
rag_system = MultimodalRAG()


def upload_pdf(pdf_file):
    if pdf_file is None:
        return "No file uploaded"
    # gr.File may hand back a tempfile-like object or a plain path string,
    # depending on the Gradio version.
    file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
    try:
        return rag_system.load_pdf(file_path)
    except Exception as e:
        return f"Error processing PDF: {str(e)}"


def save_image(image):
    if image is None:
        return None
    temp_path = "temp_image.jpg"
    # JPEG cannot store an alpha channel, so normalize to RGB before saving.
    image.convert("RGB").save(temp_path)
    return temp_path


def process_query(query, pdf_file, image=None):
    if not query.strip():
        return "Please enter a question", []
    if pdf_file is None:
        return "Please upload a PDF document first", []
    image_path = save_image(image) if image is not None else None
    try:
        return rag_system.answer_query(query, image_path)
    except Exception as e:
        return f"Error processing query: {str(e)}", []
    finally:
        # Remove the temporary image file whether or not the query succeeded.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)


# Create the Gradio interface
with gr.Blocks(title="Multimodal RAG System") as demo:
    gr.Markdown("# Multimodal RAG System")
    gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF Document")
            upload_button = gr.Button("Process PDF")
            status_output = gr.Textbox(label="Status")
            upload_button.click(
                fn=upload_pdf,
                inputs=[pdf_input],
                outputs=[status_output],
            )
        with gr.Column(scale=2):
            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
            query_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit Question")
            answer_output = gr.Textbox(label="Answer")
            sources_output = gr.JSON(label="Sources")
            submit_button.click(
                fn=process_query,
                inputs=[query_input, pdf_input, image_input],
                outputs=[answer_output, sources_output],
            )

if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")
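
# Likely dependencies for this app (inferred from the imports above; an
# assumption, not a pinned manifest, so add exact versions to requirements.txt):
#   gradio, torch, transformers, pillow, pypdf, faiss-cpu,
#   sentence-transformers, langchain, langchain-community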