Upload app_open.py
app_open.py +112 -0
app_open.py
ADDED
@@ -0,0 +1,112 @@
import gradio as gr
import os

# LangChain: history-aware text RAG over a PGVector store
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_postgres.vectorstores import PGVector
from langchain.schema import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

# LlamaIndex: multi-modal image retrieval backed by a local Qdrant store
import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import SimpleDirectoryReader, StorageContext
from llama_index.core.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# Key redacted; set OPENAI_API_KEY as a Space secret instead of hardcoding it here.
os.environ["OPENAI_API_KEY"] = "sk-..."
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chat_llm = ChatOpenAI(temperature=0.5, model="gpt-4-turbo")

# Rewrites a follow-up question into a standalone question using the chat history.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know.

context: {context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(chat_llm, qa_prompt)

# pg_connection = "postgresql+psycopg://postgres:3434@localhost:5433/mmrag"
pg_connection = "postgresql+psycopg://postgres:[email protected]:5432/postgres"
qd_client = qdrant_client.QdrantClient(path="qdrant_db")
image_store = QdrantVectorStore(client=qd_client, collection_name="image_collection")
storage_context = StorageContext.from_defaults(image_store=image_store)
openai_mm_llm = OpenAIMultiModal(model="gpt-4o", max_new_tokens=1500)

# NOTE: module-level history is shared across all sessions of the Space.
chat_history = []


def response(message, history, doc_label):
    """Answer a chat message with history-aware RAG over the selected document.

    Gradio supplies `history`, but the app keeps its own LangChain-format
    history in `chat_history` for the retriever and QA chains.
    """
    text_store = PGVector(collection_name=doc_label,
                          embeddings=embeddings,
                          connection=pg_connection)
    retriever = text_store.as_retriever()
    history_aware_retriever = create_history_aware_retriever(chat_llm,
                                                             retriever,
                                                             contextualize_q_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    result = rag_chain.invoke({"input": message, "chat_history": chat_history})
    chat_history.extend([HumanMessage(content=message),
                         AIMessage(content=result["answer"])])
    return result["answer"]


def img_retrieve(query, doc_label):
    """Return file paths of the images in ./<doc_label> most similar to the query."""
    doc_imgs = SimpleDirectoryReader(f"./{doc_label}").load_data()
    index = MultiModalVectorStoreIndex.from_documents(doc_imgs,
                                                      storage_context=storage_context)
    img_query_engine = index.as_query_engine(llm=openai_mm_llm,
                                             image_similarity_top_k=3)
    response_mm = img_query_engine.query(query)
    return [n.metadata["file_path"] for n in response_mm.metadata["image_nodes"]]


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown("# 🎨 Multi-modal RAG Chatbot")
    with gr.Row():
        gr.Markdown("Select a document from the menu, and interact with the text and images in the document.")
    with gr.Row():
        with gr.Column(scale=2):
            doc_label = gr.Dropdown(["LLaVA", "Interior"], label="Select a document:")
            chatbot = gr.ChatInterface(fn=response, additional_inputs=[doc_label], fill_height=True)
        with gr.Column(scale=1):
            sample_1 = "https://i.ytimg.com/vi/bLj_mR4Fnls/maxresdefault.jpg"
            sample_2 = "https://i.ytimg.com/vi/bOJdHU99OO8/maxresdefault.jpg"
            sample_3 = "https://blog.kakaocdn.net/dn/nqcUB/btrzYjTgjWl/jFFlIBrdkoKv4jbSyZbiEk/img.jpg"
            gallery = gr.Gallery(label="Retrieved images",
                                 show_label=True, preview=True,
                                 object_fit="contain",
                                 value=[(sample_1, "sample_1"),
                                        (sample_2, "sample_2"),
                                        (sample_3, "sample_3")])
            query = gr.Textbox(label="Enter query")
            button = gr.Button(value="Retrieve images")
            button.click(img_retrieve, [query, doc_label], gallery)

demo.launch(share=True)
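Note that the app only reads from existing stores: response() expects a PGVector collection named after the selected dropdown label ("LLaVA" or "Interior"), and img_retrieve() expects a matching local image folder (./LLaVA, ./Interior). A minimal offline ingestion sketch for the text side, assuming the same connection string and embedding model as the app; the loader, splitter settings, and "docs/llava.pdf" path are placeholders, not part of this commit:

# Hypothetical one-off ingestion script: populates the PGVector collection
# that app_open.py later queries. "docs/llava.pdf" is a placeholder path.
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector

pg_connection = "postgresql+psycopg://postgres:3434@localhost:5433/mmrag"  # same DSN as the app
chunks = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
).split_documents(PyPDFLoader("docs/llava.pdf").load())

store = PGVector(collection_name="LLaVA",  # must match the dropdown label in the app
                 embeddings=OpenAIEmbeddings(model="text-embedding-3-small"),
                 connection=pg_connection)
store.add_documents(chunks)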