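"""Gradio Space: multimodal RAG question answering.

Loads a PDF, splits it into chunks, indexes them with MiniLM embeddings in
FAISS, and answers questions through a RetrievalQA chain backed by Flan-T5
(with an OpenAI fallback). An optional image is encoded with CLIP and folded
into the query text.
"""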
import os
import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
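# Rough dependency list inferred from the imports above (package names are a
# best guess, not pinned or tested): gradio, torch, transformers, pillow,
# langchain, langchain-community, faiss-cpu, pypdf, sentence-transformers.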

# Force CPU execution by hiding any CUDA devices from PyTorch.
os.environ["CUDA_VISIBLE_DEVICES"] = ""


class MultimodalRAG:
    def __init__(self, pdf_path):
        # CLIP processor/model for embedding any query image.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        # Sentence-transformer embeddings for the FAISS text index.
        self.text_embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.documents = self._load_and_split(pdf_path)
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
        try:
            # Local Flan-T5 pipeline on CPU; generation settings belong in
            # pipeline_kwargs, and device is a top-level argument.
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                device=-1,
                pipeline_kwargs={"temperature": 0.7, "max_length": 512},
            )
        except Exception:
            # Fall back to OpenAI (requires OPENAI_API_KEY in the environment).
            from langchain_community.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )

    def _load_and_split(self, pdf_path):
        loader = PyPDFLoader(pdf_path)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return splitter.split_documents(docs)

    def _get_image_features(self, image_path):
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            return self.vision_model.get_image_features(**inputs)

    def _generate_image_description(self, image_features):
        # Placeholder: the CLIP features are not yet decoded into a real caption.
        return "an image"

    def answer_query(self, query_text, image_path=None):
        if image_path:
            feats = self._get_image_features(image_path)
            img_desc = self._generate_image_description(feats)
            full_query = f"{query_text} {img_desc}"
        else:
            full_query = query_text
        result = self.qa_chain({"query": full_query})
        answer = result["result"]
        sources = [doc.metadata for doc in result.get("source_documents", [])]
        return answer, sources
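
# Minimal programmatic usage sketch (hypothetical file name and question,
# bypassing the Gradio UI):
#
#   rag = MultimodalRAG("example.pdf")
#   answer, sources = rag.answer_query("What is the main topic?")
#   print(answer, sources)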

def run_rag(pdf_file, query, image_file=None):
    if pdf_file is None:
        return "Please upload a PDF.", []
    pdf_path = pdf_file.name
    image_path = image_file.name if image_file else None
    rag = MultimodalRAG(pdf_path)
    answer, sources = rag.answer_query(query, image_path)
    return answer, sources

iface = gr.Interface(
    fn=run_rag,
    inputs=[
        gr.File(label="PDF Document", file_types=[".pdf"]),
        gr.Textbox(label="Query", placeholder="Enter your question here..."),
        # File inputs are optional by default in current Gradio; the old
        # `optional=True` keyword is no longer accepted.
        gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"]),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.JSON(label="Source Documents"),
    ],
    title="Multimodal RAG QA",
    description="Upload a PDF, ask a question, and optionally provide an image.",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)