# Multimodal_RAG / app.py
# Uploaded by ayyuce - commit aea1ea5 (verified): "Create app.py" - 6.44 kB
import os
import tempfile

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from PIL import Image
from transformers import AutoProcessor, AutoModel
class MultimodalRAG:
    """Question answering over one PDF, with an optional image hint.

    The PDF is split into chunks, embedded with a sentence-transformers
    model and stored in a FAISS index. Questions are answered by a
    "stuff" RetrievalQA chain built on a flan-t5 pipeline (with an OpenAI
    LLM as fallback when that pipeline cannot be loaded). When an image is
    supplied, a short text description of it is appended to the retrieval
    query.
    """

    # Maximum number of characters of each source chunk returned to callers.
    SOURCE_PREVIEW_CHARS = 1000

    def __init__(self, pdf_path=None):
        """Load the models and optionally index ``pdf_path``.

        Args:
            pdf_path: Optional path to a PDF. It is indexed immediately only
                if the path exists; otherwise call :meth:`load_pdf` later.
        """
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.pdf_path = pdf_path
        self.documents = []
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        try:
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                model_kwargs={"temperature": 0.7, "max_length": 512},
            )
        except Exception as e:
            # Deliberate best-effort: fall back to a hosted LLM when the
            # local pipeline cannot be created.
            print(f"Error loading flan-t5 model: {e}")
            from langchain.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)
        if pdf_path and os.path.exists(pdf_path):
            self.load_pdf(pdf_path)

    def load_pdf(self, pdf_path):
        """Index a PDF and (re)build the retriever and QA chain.

        Args:
            pdf_path: Path of the PDF file to load.

        Returns:
            A status message naming the processed file.

        Raises:
            FileNotFoundError: If ``pdf_path`` does not exist.
            ValueError: If the PDF yields no text chunks to index.
        """
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        pages = PyPDFLoader(pdf_path).load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        chunks = text_splitter.split_documents(pages)
        if not chunks:
            # An empty chunk list cannot be indexed; fail with a clear message
            # instead of an opaque error from the vector store.
            raise ValueError(f"No extractable text found in PDF: {pdf_path}")
        vector_store = FAISS.from_documents(chunks, self.text_embeddings)
        retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )
        # Commit the new state only after every step succeeded, so a failed
        # upload leaves the previously indexed document usable.
        self.pdf_path = pdf_path
        self.documents = chunks
        self.vector_store = vector_store
        self.retriever = retriever
        self.qa_chain = qa_chain
        return f"Successfully loaded and processed PDF: {pdf_path}"

    def process_image(self, image_path):
        """Compute image features for the image stored at ``image_path``.

        Args:
            image_path: Path of the image file.

        Returns:
            The result of ``vision_model.get_image_features`` for the image,
            or ``None`` (after printing a warning) if the path does not exist.
        """
        if not os.path.exists(image_path):
            print(f"Warning: Image path {image_path} does not exist")
            return None
        # The context manager closes the underlying file once the processor
        # has read the pixel data.
        with Image.open(image_path) as image:
            inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.vision_model.get_image_features(**inputs)
        return image_features

    def generate_image_description(self, image_features):
        """Return a text description for ``image_features``.

        Placeholder implementation: the features are ignored and the
        constant string ``"a photo"`` is returned.
        """
        return "a photo"

    def retrieve_related_documents(self, query_text, image_path=None):
        """Retrieve the document chunks relevant to a question.

        Args:
            query_text: The user's question.
            image_path: Optional image path; when its features can be
                computed, the image description is appended to the query.

        Returns:
            The documents returned by the retriever, or an empty list when
            no PDF has been indexed yet.
        """
        if self.retriever is None:
            return []
        enhanced_query = query_text
        if image_path:
            image_features = self.process_image(image_path)
            if image_features is not None:
                image_query = self.generate_image_description(image_features)
                enhanced_query = f"{query_text} {image_query}"
        return self.retriever.get_relevant_documents(enhanced_query)

    def answer_query(self, query_text, image_path=None):
        """Answer a question from the indexed PDF.

        Args:
            query_text: The user's question.
            image_path: Optional path of an image giving extra context.

        Returns:
            A ``(answer, sources)`` tuple. ``sources`` is a list of text
            previews of the chunks the answer was generated from. When no
            PDF has been indexed, the answer is a prompt to upload one and
            ``sources`` is empty.
        """
        if self.vector_store is None or self.qa_chain is None:
            # Same two-element shape as the success path, so callers can
            # always unpack the result.
            return "Please upload a PDF document first.", []
        docs = self.retrieve_related_documents(query_text, image_path)
        # Answer from the documents retrieved above. Invoking the full chain
        # would run a second retrieval with the bare question and throw the
        # image-enhanced results away.
        answer = self.qa_chain.combine_documents_chain.run(
            input_documents=docs,
            question=query_text,
        )
        sources = [self._source_preview(doc.page_content) for doc in docs]
        return answer, sources

    @classmethod
    def _source_preview(cls, text):
        """Return ``text`` cut to the preview length, with "..." only if cut."""
        if len(text) <= cls.SOURCE_PREVIEW_CHARS:
            return text
        return text[:cls.SOURCE_PREVIEW_CHARS] + "..."
# Single module-level RAG instance used by the upload_pdf and process_query
# callbacks; constructing it runs MultimodalRAG.__init__ (model loading) at
# import time.
rag_system = MultimodalRAG()
def upload_pdf(pdf_file):
    """Index an uploaded PDF in the shared RAG system and report the outcome.

    Args:
        pdf_file: Value delivered by the file-upload component: ``None`` when
            nothing was uploaded, a filesystem path (string or path-like),
            or a file wrapper exposing its path as ``.name``.

    Returns:
        A status string: the success message from ``rag_system.load_pdf``,
        ``"No file uploaded"``, or an ``"Error processing PDF: ..."`` message.
    """
    if pdf_file is None:
        return "No file uploaded"
    try:
        # Depending on the Gradio version / component ``type``, the upload
        # arrives as a plain path or as a temp-file wrapper; accept both.
        if isinstance(pdf_file, (str, os.PathLike)):
            file_path = os.fspath(pdf_file)
        else:
            file_path = pdf_file.name
        return rag_system.load_pdf(file_path)
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return f"Error processing PDF: {str(e)}"
def save_image(image):
    """Write an uploaded image to a unique temporary JPEG file.

    Args:
        image: A PIL image (anything exposing ``mode``, ``convert`` and
            ``save``), or ``None``.

    Returns:
        The path of the written ``.jpg`` file, or ``None`` when ``image`` is
        ``None``. The caller is responsible for deleting the file.

    Raises:
        Whatever ``image.save`` raises; the temporary file is removed first.
    """
    if image is None:
        return None
    # JPEG cannot store alpha or palette data; normalise such modes first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    # A unique file per call keeps concurrent requests from overwriting or
    # deleting each other's image, which a fixed file name would allow.
    fd, temp_path = tempfile.mkstemp(prefix="temp_image_", suffix=".jpg")
    os.close(fd)
    try:
        image.save(temp_path)
    except Exception:
        os.remove(temp_path)
        raise
    return temp_path
def process_query(query, pdf_file, image=None):
    """Answer a question for the UI, optionally using an uploaded image.

    Args:
        query: The question text; ``None`` or blank text is rejected.
        pdf_file: Current value of the PDF upload component; only checked
            for being present.
        image: Optional uploaded image, saved to a temporary file for the
            duration of the query.

    Returns:
        An ``(answer, sources)`` tuple. On validation failure or error the
        first element is a message and ``sources`` is an empty list.
    """
    if not query or not query.strip():
        return "Please enter a question", []
    if pdf_file is None:
        return "Please upload a PDF document first", []
    image_path = save_image(image) if image is not None else None
    try:
        result = rag_system.answer_query(query, image_path)
        if isinstance(result, str):
            # A bare message (e.g. no PDF indexed yet) has no sources.
            return result, []
        answer, sources = result
        return answer, sources
    except Exception as e:
        # UI boundary: report the failure instead of crashing the callback.
        return f"Error processing query: {str(e)}", []
    finally:
        # Remove the temporary image on every exit path.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
# Gradio UI: a narrow column for uploading/indexing the PDF next to a wider
# column for the optional image, the question, and the answer with sources.
with gr.Blocks(title="Multimodal RAG System") as demo:
    gr.Markdown("# Multimodal RAG System")
    gr.Markdown("Upload a PDF document and ask questions about it. You can also add an image for multimodal context.")

    with gr.Row():
        # PDF upload and indexing status.
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF Document")
            upload_button = gr.Button("Process PDF")
            status_output = gr.Textbox(label="Status")
            upload_button.click(fn=upload_pdf, inputs=[pdf_input], outputs=[status_output])

        # Question, optional image, and results.
        with gr.Column(scale=2):
            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
            query_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit Question")
            answer_output = gr.Textbox(label="Answer")
            sources_output = gr.JSON(label="Sources")
            submit_button.click(
                fn=process_query,
                inputs=[query_input, pdf_input, image_input],
                outputs=[answer_output, sources_output],
            )

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")