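"""Multimodal RAG demo: answer questions over an uploaded PDF, with an optional image input.

Pipeline as implemented below: PyPDFLoader -> RecursiveCharacterTextSplitter -> FAISS
vector store over MiniLM sentence embeddings -> RetrievalQA with a local Flan-T5
pipeline (OpenAI as fallback), served through a Gradio interface.

Likely dependencies (inferred from the imports, not pinned here): gradio, torch,
transformers, langchain, langchain-community, sentence-transformers, faiss-cpu,
pypdf, Pillow.
"""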
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS

# Run everything on CPU (e.g. on a CPU-only Space).
os.environ["CUDA_VISIBLE_DEVICES"] = ""


class MultimodalRAG:
    """Retrieval-augmented QA over a PDF, with a placeholder image pathway via CLIP."""

    def __init__(self, pdf_path):
        # CLIP processor/model for encoding an optional query image.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        # Sentence-transformer embeddings for the PDF text chunks.
        self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.documents = self._load_and_split(pdf_path)
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
        try:
            # Local Flan-T5 pipeline on CPU; generation settings belong in pipeline_kwargs,
            # and the device is a top-level argument rather than a model kwarg.
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                device=-1,
                pipeline_kwargs={"temperature": 0.7, "max_length": 512},
            )
        except Exception:
            # Fall back to OpenAI (requires OPENAI_API_KEY in the environment).
            from langchain_community.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )

    def _load_and_split(self, pdf_path):
        # Load the PDF page by page, then split into overlapping chunks for retrieval.
        loader = PyPDFLoader(pdf_path)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return splitter.split_documents(docs)

    def _get_image_features(self, image_path):
        # Encode the image with CLIP; no gradients are needed at inference time.
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            return self.vision_model.get_image_features(**inputs)

    def _generate_image_description(self, image_features):
        # Placeholder: CLIP yields embeddings, not captions, so the image is only acknowledged generically.
        return "an image"

    def answer_query(self, query_text, image_path=None):
        # If an image is supplied, fold its (placeholder) description into the query text.
        if image_path:
            feats = self._get_image_features(image_path)
            img_desc = self._generate_image_description(feats)
            full_query = f"{query_text} {img_desc}"
        else:
            full_query = query_text
        result = self.qa_chain({"query": full_query})
        answer = result["result"]
        sources = [doc.metadata for doc in result.get("source_documents", [])]
        return answer, sources


def run_rag(pdf_file, query, image_file=None):
    if pdf_file is None:
        return "Please upload a PDF.", []
    pdf_path = pdf_file.name
    image_path = image_file.name if image_file else None
    # Note: the index is rebuilt on every request; fine for a demo, slow for large PDFs.
    rag = MultimodalRAG(pdf_path)
    answer, sources = rag.answer_query(query, image_path)
    return answer, sources


iface = gr.Interface(
    fn=run_rag,
    inputs=[
        gr.File(label="PDF Document", file_types=[".pdf"]),
        gr.Textbox(label="Query", placeholder="Enter your question here..."),
        # File inputs can simply be left empty; recent Gradio versions no longer accept an `optional` kwarg here.
        gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"]),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.JSON(label="Source Documents"),
    ],
    title="Multimodal RAG QA",
    description="Upload a PDF, ask a question, optionally provide an image.",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
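
# A minimal local run (assuming the dependencies listed in the module docstring are installed):
#   python app.py
# then open http://localhost:7860 in a browser. The filename "app.py" follows the usual
# Hugging Face Spaces convention and is an assumption here, not something the script enforces.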