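"""Multimodal RAG demo: answer questions about an uploaded PDF, optionally
conditioned on an image, through a Gradio web interface.

Pipeline: PyPDFLoader -> RecursiveCharacterTextSplitter -> MiniLM embeddings
-> FAISS retriever -> FLAN-T5 via a RetrievalQA "stuff" chain. CLIP encodes
the optional image (see _generate_image_description for its current limits).

Suggested dependencies (assumed, adjust versions as needed):
    pip install torch transformers langchain langchain-community faiss-cpu \
        sentence-transformers pypdf gradio pillow
"""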
import os
import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Force CPU execution; all models below run (slowly) without a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

class MultimodalRAG:
    def __init__(self, pdf_path):
        # CLIP encoder for the optional image input.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        # Sentence-transformer embeddings for the PDF text chunks.
        self.text_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

        # Chunk the PDF and index the chunks in an in-memory FAISS store.
        self.documents = self._load_and_split(pdf_path)
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)

        try:
            # `device` is a direct argument of from_model_id (-1 = CPU), and
            # generation settings belong in pipeline_kwargs, not model_kwargs.
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                device=-1,
                pipeline_kwargs={"temperature": 0.7, "max_length": 512}
            )
        except Exception:
            # Fallback if the local model cannot be loaded; requires
            # OPENAI_API_KEY to be set in the environment.
            from langchain_community.llms import OpenAI
            self.llm = OpenAI(temperature=0.7)

        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True
        )

    def _load_and_split(self, pdf_path):
        loader = PyPDFLoader(pdf_path)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return splitter.split_documents(docs)

    def _get_image_features(self, image_path):
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            return self.vision_model.get_image_features(**inputs)

    def _generate_image_description(self, image_features):
        # Placeholder: CLIP yields embeddings, not captions, so the computed
        # features are currently unused. Swap in a captioning model (e.g.
        # BLIP) to turn the image into text the retriever can actually use.
        return "an image"

    def answer_query(self, query_text, image_path=None):
        if image_path:
            feats = self._get_image_features(image_path)
            img_desc = self._generate_image_description(feats)
            full_query = f"{query_text} {img_desc}"
        else:
            full_query = query_text

        result = self.qa_chain.invoke({"query": full_query})
        answer = result["result"]
        sources = [doc.metadata for doc in result.get("source_documents", [])]
        return answer, sources


def run_rag(pdf_file, query, image_file=None):
    if pdf_file is None:
        return "Please upload a PDF.", []

    pdf_path = pdf_file.name
    image_path = image_file.name if image_file else None

    # Note: models and the FAISS index are rebuilt on every request; cache the
    # MultimodalRAG instance per PDF for anything beyond a quick demo.
    rag = MultimodalRAG(pdf_path)
    answer, sources = rag.answer_query(query, image_path)
    return answer, sources

iface = gr.Interface(
    fn=run_rag,
    inputs=[
        gr.File(label="PDF Document", file_types=[".pdf"]),
        gr.Textbox(label="Query", placeholder="Enter your question here..."),
        gr.File(label="Optional Image", file_types=[".png", ".jpg", ".jpeg"], optional=True)
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.JSON(label="Source Documents")
    ],
    title="Multimodal RAG QA",
    description="Upload a PDF, ask a question, optionally provide an image."
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)