import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets  # only needed by the commented-out dataset-loading path below
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions

"""
Sample questions:
    pass me some fun general facts from the retriever
"""

load_dotenv()


def dummy():
    # Unused placeholder kept from the original Space.
    pass
class RetrieverTool(Tool):
    """Since we need to attach a vector DB as an attribute of the tool,
    we cannot simply use the @tool decorator constructor.

    Retrieval uses Chroma's default embedding function for semantic search.
    For higher retrieval accuracy, swap in a stronger embedding model
    (see the MTEB Leaderboard for accuracy rankings and the commented
    sketch after this class).
    """

    name = "retriever"
    description = """Uses semantic search to retrieve the parts of the knowledge base
    that could be most relevant to answer your query.
    Afterwards, this tool summarizes the findings from the retrieved documents."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"
    def __init__(self, docs: list[Document], **kwargs):
        super().__init__(**kwargs)
        chroma_data_path = "chroma_data/"
        if not os.path.isdir(chroma_data_path):
            print(f"Creating Chroma data directory at {chroma_data_path}")
            os.makedirs(chroma_data_path, exist_ok=True)
        collection_name = "demo_docs"
        # Chroma's built-in default embedding function (a small MiniLM ONNX model).
        embedding_func = embedding_functions.DefaultEmbeddingFunction()
        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},  # cosine distance for the HNSW index
        )
        collection.upsert(
            documents=[doc.page_content for doc in docs],
            ids=[f"id{i}" for i in range(len(docs))],
        )
        self.collection = collection
    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.collection.query(query_texts=[query], n_results=5)
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {doc_id} =====\n{doc}"
            for doc_id, doc in zip(docs["ids"][0], docs["documents"][0])
        )
        # `model` is the module-level HfApiModel defined in the __main__ block below.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Summarize this text: " + retrieved_text}
                ],
            }
        ]
        return retrieved_text + "\n" + model(messages).content
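

# A minimal sketch of swapping Chroma's default embedding function for a
# stronger sentence-transformers model picked from the MTEB Leaderboard.
# Assumptions: `sentence-transformers` is installed, and the model name below
# is just an illustrative choice, not one this demo ships with.
#
# embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
#     model_name="BAAI/bge-small-en-v1.5"
# )
# # ...then pass `embedding_function=embedding_func` to get_or_create_collection
# # in RetrieverTool.__init__ above.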
if __name__ == "__main__":
    # Alternative: build the documents by hand from the raw dataset.
    # knowledge_base = datasets.load_dataset("MuskumPillerum/General-Knowledge", split="train")
    # print(knowledge_base.column_names)
    # source_docs = [
    #     Document(page_content=doc["Answer"], metadata={"question": doc["Question"]})
    #     for doc in knowledge_base
    # ]

    # Load the "Answer" column of the dataset and keep the first 100 documents.
    source_docs = HuggingFaceDatasetLoader("MuskumPillerum/General-Knowledge", "Answer").load()[:100]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)

    retriever_tool = RetrieverTool(docs_processed)
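
    # Optional sanity check (a quick sketch, safe to leave commented out):
    # confirm how many chunks the splitter produced before they are indexed.
    # print(f"Indexed {len(docs_processed)} chunks from {len(source_docs)} documents")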
    # Local model path; not working at the moment.
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct",
    # )

    # Remote inference via the HF Inference API, authenticated with a token
    # stored in the `my_first_agents_hf_tokens` environment variable (see .env).
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
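
    # Note (assumption about huggingface_hub behavior): if that env var is
    # unset, os.getenv returns None and HfApiModel falls back to the usual
    # auth chain (HF_TOKEN or a cached `huggingface-cli login`).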
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )
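
    # A minimal headless sketch (no Gradio UI), useful for debugging the agent
    # loop from a terminal before wiring up the interface below:
    # answer = agent.run("Pass me some fun general facts from the retriever tool")
    # print(answer)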
    def enter_message(new_message, conversation_history):
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        # Stream the agent's intermediate steps into the chat as they arrive.
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history: list):
        agent.memory.reset()
        # list.clear() returns None, so return a fresh empty history explicitly.
        return [], ""
    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic RAG on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="pass me some fun general facts from the retriever tool",
        )
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")
        reply_button_click_event = enter_button.click(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        submit_event = textbox.submit(enter_message, [textbox, chatbot], [textbox, chatbot])
        clear_messages_button.click(
            fn=clear_message,
            inputs=chatbot,
            outputs=[chatbot, textbox],
            cancels=[reply_button_click_event, submit_event],
        )
        # No fn needed: this click only cancels the in-flight generation events.
        stop_generating_button.click(None, cancels=[reply_button_click_event, submit_event])

    b.launch()