import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions
"""
sample questions
pass me some fun general facts from the retreiver
"""
load_dotenv()
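
# On Hugging Face ZeroGPU Spaces, at least one @spaces.GPU-decorated function is
# expected so the Space can request a GPU; this no-op appears to exist only to
# satisfy that requirement.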
@spaces.GPU
def dummy():
    pass
class RetrieverTool(Tool):
    """Since we need to add a vector DB as an attribute of the tool,
    we cannot simply use the tool constructor with a @tool decorator.
    Retrieval uses ChromaDB's default embedding function, which is fast.
    For more retrieval accuracy, you can replace it with a stronger
    embedding model; check out the MTEB leaderboard for accuracy rankings.
    """
name = "retriever"
description = """Uses semantic search to retrieve the parts of transformers documentation
that could be most relevant to answer your query.
Afterwards, this tool summaries the findings from the extracted document
"""
inputs = {
"query": {
"type": "string",
"description": "The python list of queries to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
}
}
output_type = "string"

    def __init__(self, docs: list[Document], **kwargs):
        super().__init__(**kwargs)
        chroma_data_path = "chroma_data/"
        os.makedirs(chroma_data_path, exist_ok=True)
        collection_name = "demo_docs"
        embedding_func = embedding_functions.DefaultEmbeddingFunction()
        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},
        )
        # Chroma embeds the documents with the collection's embedding function on upsert.
        collection.upsert(
            documents=[doc.page_content for doc in docs],
            ids=[f"id{i}" for i in range(len(docs))],
        )
        self.collection = collection

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.collection.query(query_texts=[query], n_results=5)
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            [
                f"\n\n===== Document {doc_id} =====\n" + doc
                for doc_id, doc in zip(docs["ids"][0], docs["documents"][0])
            ]
        )
        # Ask the model (defined in the __main__ block below) to summarize the hits.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Summarize this text:" + retrieved_text}
                ],
            }
        ]
        return retrieved_text + "\n" + model(messages).content
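
# A possible accuracy upgrade (illustrative sketch, not part of the original app):
# swap Chroma's default embedding function (all-MiniLM-L6-v2 via ONNX) for a
# stronger sentence-transformers model from the MTEB leaderboard, e.g.:
#
#   embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
#       model_name="sentence-transformers/all-mpnet-base-v2"
#   )
#
# and pass it to get_or_create_collection() in __init__ above.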
if __name__ == "__main__":
    # Alternative: build the Documents by hand from the raw dataset.
    # knowledge_base = datasets.load_dataset("MuskumPillerum/General-Knowledge", split="train")
    # print(knowledge_base.column_names)
    # source_docs = [
    #     Document(
    #         page_content=doc["Answer"], metadata={"question": doc["Question"]}
    #     )
    #     for doc in knowledge_base
    # ]
    source_docs = HuggingFaceDatasetLoader("MuskumPillerum/General-Knowledge", "Answer").load()[:100]
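
    # Chunk the loaded documents into ~500-character pieces with 50-character
    # overlap so retrieved passages stay short enough to summarize.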
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)

    retriever_tool = RetrieverTool(docs_processed)
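
    # retriever_tool now wraps a persistent Chroma collection built from the
    # chunks above; its forward() returns the top-5 matches plus an LLM summary.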

    # Not working at the moment
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct"
    # )
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
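
    # The token is read from the environment populated by load_dotenv() above;
    # `my_first_agents_hf_tokens` is presumably set as a Space secret or .env entry.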

    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )
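
    # The agent can also be exercised headlessly (illustrative, commented out):
    # print(agent.run("pass me some fun general facts from the retriever tool"))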

    def enter_message(new_message, conversation_history):
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        # Stream the agent's intermediate steps into the chat as they arrive.
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history: list):
        agent.memory.reset()
        # list.clear() returns None, so return a fresh empty history instead.
        return [], ""

    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic RAG on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="pass me some fun general facts from the retriever tool",
        )
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")

        reply_button_click_event = enter_button.click(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        submit_event = textbox.submit(enter_message, [textbox, chatbot], [textbox, chatbot])
        clear_messages_button.click(
            fn=clear_message,
            inputs=chatbot,
            outputs=[chatbot, textbox],
            cancels=[reply_button_click_event, submit_event],
        )
        stop_generating_button.click(fn=None, cancels=[reply_button_click_event, submit_event])

    b.launch()