import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS


class MultimodalRAG:
    """Retrieval-augmented QA over a PDF, with optional image context via CLIP."""

    def __init__(self, pdf_path=None):
        # CLIP encodes uploaded images; MiniLM embeds the PDF text chunks.
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.pdf_path = pdf_path
        self.documents = []
        self.vector_store = None
        self.retriever = None
        self.qa_chain = None
        try:
            self.llm = HuggingFacePipeline.from_model_id(
                model_id="google/flan-t5-large",
                task="text2text-generation",
                # Generation settings belong to the pipeline, not the model loader.
                pipeline_kwargs={"max_length": 512, "do_sample": True, "temperature": 0.7},
            )
        except Exception as e:
            print(f"Error loading flan-t5 model: {e}")
            # Fallback requires the OPENAI_API_KEY environment variable.
            from langchain_community.llms import OpenAI

            self.llm = OpenAI(temperature=0.7)
        if pdf_path and os.path.exists(pdf_path):
            self.load_pdf(pdf_path)

    def load_pdf(self, pdf_path):
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        loader = PyPDFLoader(pdf_path)
        self.documents = loader.load()
        # Overlapping chunks keep sentences intact across chunk boundaries.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.documents = text_splitter.split_documents(self.documents)
        self.vector_store = FAISS.from_documents(self.documents, self.text_embeddings)
        self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 2})
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
        )
        return f"Successfully loaded and processed PDF: {pdf_path}"

    def process_image(self, image_path):
        if not os.path.exists(image_path):
            print(f"Warning: Image path {image_path} does not exist")
            return None
        image = Image.open(image_path)
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.vision_model.get_image_features(**inputs)
        return image_features

    def generate_image_description(self, image_features):
        # Placeholder: CLIP yields embeddings, not captions, so this returns a
        # generic description. Swap in a captioning model (e.g. BLIP) for real use.
        return "a photo"

    def _build_query(self, query_text, image_path=None):
        # Append a textual description of the image, when one is supplied, so
        # the image context influences retrieval alongside the question.
        if image_path:
            image_features = self.process_image(image_path)
            if image_features is not None:
                return f"{query_text} {self.generate_image_description(image_features)}"
        return query_text

    def retrieve_related_documents(self, query_text, image_path=None):
        enhanced_query = self._build_query(query_text, image_path)
        return self.retriever.invoke(enhanced_query)

    def answer_query(self, query_text, image_path=None):
        if not self.vector_store or not self.qa_chain:
            # Always return an (answer, sources) pair so callers can unpack it.
            return "Please upload a PDF document first.", []
        enhanced_query = self._build_query(query_text, image_path)
        result = self.qa_chain.invoke({"query": enhanced_query})
        answer = result["result"]
        sources = [doc.page_content[:1000] + "..." for doc in result["source_documents"]]
        return answer, sources


rag_system = MultimodalRAG()


def upload_pdf(pdf_file):
    if pdf_file is None:
        return "No file uploaded"
    # gr.File yields a tempfile-like object (Gradio 3.x) or a path string (4.x).
    file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
    try:
        return rag_system.load_pdf(file_path)
    except Exception as e:
        return f"Error processing PDF: {str(e)}"


def save_image(image):
    if image is None:
        return None
    temp_path = "temp_image.jpg"
    # JPEG cannot store an alpha channel, so normalize to RGB before saving.
    image.convert("RGB").save(temp_path)
    return temp_path


def process_query(query, pdf_file, image=None):
    if not query.strip():
        return "Please enter a question", []
    if pdf_file is None:
        return "Please upload a PDF document first", []
    image_path = save_image(image) if image is not None else None
    try:
        return rag_system.answer_query(query, image_path)
    except Exception as e:
        return f"Error processing query: {str(e)}", []
    finally:
        # Remove the temporary image whether or not the query succeeded.
        if image_path and os.path.exists(image_path):
            os.remove(image_path)


# Create the Gradio interface.
with gr.Blocks(title="Multimodal RAG System") as demo:
    gr.Markdown("# Multimodal RAG System")
    gr.Markdown(
        "Upload a PDF document and ask questions about it. "
        "You can also add an image for multimodal context."
    )
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="Upload PDF Document")
            upload_button = gr.Button("Process PDF")
            status_output = gr.Textbox(label="Status")
            upload_button.click(
                fn=upload_pdf,
                inputs=[pdf_input],
                outputs=[status_output],
            )
        with gr.Column(scale=2):
            image_input = gr.Image(label="Optional: Upload an Image", type="pil")
            query_input = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit Question")
            answer_output = gr.Textbox(label="Answer")
            sources_output = gr.JSON(label="Sources")
            submit_button.click(
                fn=process_query,
                inputs=[query_input, pdf_input, image_input],
                outputs=[answer_output, sources_output],
            )

if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")
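
# A minimal sketch of programmatic use without the Gradio UI, left commented
# out so it never runs on import; the PDF path and question are hypothetical:
#
#   rag = MultimodalRAG(pdf_path="docs/example.pdf")
#   answer, sources = rag.answer_query("What is the main conclusion?")
#   print(answer)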