agentic_rag / app.py
themissingCRAM
ui and clear button
ccfea6c
raw
history blame contribute delete
5.76 kB
import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions
"""
sample questions
pass me some fun general facts from the retriever
"""
load_dotenv()
@spaces.GPU
def dummy():
    """Do nothing; exists only as a function decorated with ``spaces.GPU``.

    NOTE(review): presumably kept so the hosting environment sees at least
    one GPU-decorated function -- confirm before removing.
    """
    return None
class RetrieverTool(Tool):
    """Retrieval tool backed by a persistent Chroma collection.

    A class is used instead of the simple ``@tool`` decorator because the
    vector collection has to be stored as an attribute of the tool.

    Documents are embedded with Chroma's default embedding function and the
    collection is configured for cosine distance. For more accurate
    retrieval, swap in a stronger embedding model (check the MTEB
    leaderboard for accuracy rankings).
    """

    name = "retriever"
    # The description is what the agent reads to decide when to call the
    # tool, so it must describe the indexed documents, not a specific corpus.
    description = (
        "Uses semantic search to retrieve the parts of the indexed documents\n"
        "that could be most relevant to answer your query.\n"
        "Afterwards, this tool summaries the findings from the extracted document\n"
    )
    inputs = {
        "query": {
            "type": "string",
            # Must agree with the declared type: a single string, not a list.
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs: list[Document], **kwargs):
        """Index ``docs`` into the persistent ``demo_docs`` Chroma collection.

        Args:
            docs: Documents whose ``page_content`` is embedded and stored.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)

        chroma_data_path = "chroma_data/"
        # ``exist_ok=True`` already covers the "directory exists" case, so no
        # separate existence check is needed.
        os.makedirs(chroma_data_path, exist_ok=True)

        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name="demo_docs",
            embedding_function=embedding_functions.DefaultEmbeddingFunction(),
            metadata={"hnsw:space": "cosine"},
        )
        # Positional ids combined with upsert: re-running against the same
        # persisted directory rewrites entries with the same ids.
        collection.upsert(
            documents=[doc.page_content for doc in docs],
            ids=[f"id{i}" for i in range(len(docs))],
        )
        self.collection = collection

    def forward(self, query: str) -> str:
        """Retrieve the closest documents for ``query`` and append a summary.

        Args:
            query: Search text; must be a ``str``.

        Returns:
            The retrieved documents, each labelled with its collection id,
            followed by a newline and the summary text returned by the
            module-level ``model``.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        results = self.collection.query(query_texts=[query], n_results=5)
        # Results are nested per query text; index 0 is the single query
        # issued above.
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {doc_id} =====\n" + doc
            for doc_id, doc in zip(results["ids"][0], results["documents"][0])
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "summaries this text:" + retrieved_text}
                ],
            }
        ]
        # ``model`` is a module-level global assigned by the script entry
        # point; it is looked up at call time, so it must exist before the
        # first query is made.
        return retrieved_text + "\n" + model(messages).content
if __name__ == "__main__":
    # Knowledge base: the "Answer" column of the dataset, limited to the
    # first 100 documents.
    # (Alternative kept for reference: datasets.load_dataset(...) mapped to
    # Document(page_content=doc["Answer"], metadata={"question": doc["Question"]}).)
    source_docs = HuggingFaceDatasetLoader(
        "MuskumPillerum/General-Knowledge", "Answer"
    ).load()[:100]

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    retriever_tool = RetrieverTool(docs_processed)

    # NOTE: running locally through TransformersModel(device_map="cuda",
    # model_id="meta-llama/Llama-3.2-3B-Instruct") is not working at the
    # moment, so the hosted inference API is used instead.
    #
    # ``model`` must stay a module-level name (do not move this block into a
    # function): RetrieverTool.forward reads it as a global.
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )

    def enter_message(new_message, conversation_history):
        """Run the agent on ``new_message`` and stream its steps to the chat.

        Generator used as a Gradio event handler. Each yielded pair maps to
        the ``[textbox, chatbot]`` outputs: an empty string to clear the
        textbox and the updated message list for the chatbot.
        """
        if conversation_history is None:
            conversation_history = []
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        # Show the user's message (and clear the textbox) right away instead
        # of waiting for the agent's first streamed step.
        yield "", conversation_history
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history: list):
        """Reset the agent's memory and return empty chatbot/textbox values.

        ``chat_history`` is accepted because the event wiring passes the
        chatbot value, but it is not modified; ``list.clear()`` returns
        ``None``, so an explicit empty list is returned instead.
        """
        agent.memory.reset()
        return [], ""

    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic rag on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="pass me some fun general facts from the retreiver tool",
        )
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")

        # Both the button and pressing Enter in the textbox start a run; keep
        # the event handles so they can be cancelled below.
        reply_button_click_event = enter_button.click(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        submit_event = textbox.submit(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        # Clearing also cancels any in-flight generation and resets the
        # agent's memory.
        clear_messages_button.click(
            fn=clear_message,
            inputs=chatbot,
            outputs=[chatbot, textbox],
            cancels=[reply_button_click_event, submit_event],
        )
        stop_generating_button.click(cancels=[reply_button_click_event, submit_event])

    b.launch()