File size: 5,756 Bytes
f56053f
 
d4e94c6
f56053f
 
fa354b9
 
 
c69bef9
 
642dafe
 
fa354b9
f7ada6e
1d4ddf1
f7ada6e
eafecbf
 
 
 
f7ada6e
e0c970e
 
 
f56053f
5b04227
 
 
fa354b9
 
edc9ea5
fa354b9
59f2a72
fa354b9
59f2a72
fa354b9
7c794db
f56053f
fa354b9
 
f7ada6e
 
b626f02
f7ada6e
fa354b9
 
edc9ea5
071e96e
fa354b9
 
 
 
a441795
fa354b9
edc9ea5
 
5de93d8
44bd417
8466e43
5de93d8
34f4144
edc9ea5
3f74d57
edc9ea5
2f46a72
642dafe
 
071e96e
642dafe
 
 
 
fa354b9
6004c9d
 
a441795
b626f02
59f2a72
a441795
44bd417
59f2a72
 
60a84d8
 
 
 
 
 
 
 
 
80aab0c
fa354b9
3f74d57
fa354b9
c69bef9
 
 
 
 
 
 
 
81c1287
fa354b9
 
 
 
 
 
 
 
c69bef9
ead9c31
d4e94c6
 
 
 
 
 
 
 
 
c69bef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccfea6c
eafecbf
c69bef9
 
3d139cb
c69bef9
53bccaa
60a84d8
c160d17
 
f039d9f
 
c69bef9
ccfea6c
 
 
 
c69bef9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions

"""
sample questions



pass me some fun general facts from the retriever

"""

# Load environment variables (e.g. the HF API token read below) from a local .env file.
load_dotenv()

@spaces.GPU
def dummy():
    """No-op placeholder decorated with @spaces.GPU.

    NOTE(review): presumably exists only so the Hugging Face Space registers
    a GPU requirement (ZeroGPU Spaces need at least one @spaces.GPU-decorated
    function) — confirm against the Space configuration.
    """
    pass

class RetrieverTool(Tool):
    """Semantic-search retriever over a persistent Chroma vector store.

    Since we need to attach a vector DB to the tool instance, we subclass
    ``Tool`` instead of using the simple ``@tool`` decorator (which only
    wraps plain functions).

    Retrieval uses Chroma's default embedding function with cosine
    similarity. For an accuracy ranking of alternative embedding models,
    check the MTEB Leaderboard.
    """

    name = "retriever"
    description = """Uses semantic search to retrieve the parts of transformers documentation 
    that could be most relevant to answer your query. 
    Afterwards, this tool summarizes the findings from the extracted documents.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The python list of queries to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs: list[Document], **kwargs):
        """Index ``docs`` into a persistent Chroma collection.

        Args:
            docs: pre-chunked LangChain documents to embed and store.
        """
        super().__init__(**kwargs)
        chroma_data_path = "chroma_data/"
        # exist_ok=True already tolerates a pre-existing directory, so the
        # previous isdir check (and its debug print) was redundant.
        os.makedirs(chroma_data_path, exist_ok=True)

        collection_name = "demo_docs"
        embedding_func = embedding_functions.DefaultEmbeddingFunction()
        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},
        )
        # Deterministic ids make repeated runs upsert (overwrite) the same
        # entries instead of duplicating documents in the persistent store.
        collection.upsert(
            documents=[doc.page_content for doc in docs],
            ids=[f"id{i}" for i in range(len(docs))],
        )
        self.collection = collection

    def forward(self, query: str) -> str:
        """Return the top-5 matching documents followed by an LLM summary."""
        assert isinstance(query, str), "Your search query must be a string"
        results = self.collection.query(query_texts=[query], n_results=5)
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {doc_id} =====\n" + doc
            for doc_id, doc in zip(results["ids"][0], results["documents"][0])
        )
        # NOTE(review): `model` is a module-level global assigned in the
        # __main__ section; calling forward() before that assignment runs
        # raises NameError.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "summarize this text:" + retrieved_text}
                ],
            }
        ]
        return retrieved_text + "\n" + model(messages).content



if __name__ == "__main__":
    # Alternative loading path kept for reference: build LangChain Documents
    # by hand from the raw dataset instead of using HuggingFaceDatasetLoader.
    # knowledge_base = datasets.load_dataset("MuskumPillerum/General-Knowledge", split="train")
    # print(knowledge_base.column_names)
    # source_docs = [
    #     Document(
    #         page_content=doc["Answer"], metadata={"question": doc["Question"]}
    #     )
    #     for doc in knowledge_base
    # ]
    # Load the "Answer" column of the dataset as documents; only the first
    # 100 entries are kept, presumably to keep demo indexing fast — confirm.
    source_docs = HuggingFaceDatasetLoader("MuskumPillerum/General-Knowledge", "Answer").load()[:100]
    # Chunk documents into ~500-char overlapping pieces before embedding.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    retriever_tool = RetrieverTool(docs_processed)

    # Local TransformersModel variant — not working at the moment.
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct"
    # )
    # NOTE: `model` is also read by RetrieverTool.forward through module
    # global scope, not only passed to the agent below.
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens")
    )
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )


    def enter_message(new_message, conversation_history):
        """Generator handler for a chat turn.

        Appends the user's message to the history, then streams each of the
        agent's intermediate messages into the chat, clearing the textbox
        ("" as first output) on every yield.
        """
        history = conversation_history
        history.append(gr.ChatMessage(role="user", content=new_message))
        for agent_message in stream_to_gradio(agent, new_message):
            history.append(agent_message)
            yield "", history

    def clear_message(chat_history: list):
        """Reset the agent's memory and empty the chat history.

        Returns the (now empty) history list and an empty string so Gradio
        clears both the chatbot and the textbox. The previous version
        returned ``chat_history.clear()`` — i.e. ``None`` — instead of the
        emptied list, sending ``None`` to the Chatbot output.
        """
        agent.memory.reset()
        chat_history.clear()
        return chat_history, ""

    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic rag on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        # Pre-filled with a sample question so the demo is one click away.
        # (User-facing typo fixed: "retreiver" -> "retriever".)
        textbox = gr.Textbox(lines=1, label="chat message (with default sample question)",
                             value="pass me some fun general facts from the retriever tool")
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")
        # Both the enter button and pressing Enter in the textbox run the agent;
        # keep the event handles so other listeners can cancel them.
        reply_button_click_event = enter_button.click(enter_message, [textbox, chatbot], [textbox, chatbot])
        submit_event = textbox.submit(enter_message, [textbox, chatbot], [textbox, chatbot])
        # Clearing also cancels any in-flight generation.
        clear_messages_button.click(fn=clear_message, inputs=chatbot, outputs=[chatbot, textbox],
                                    cancels=[reply_button_click_event, submit_event])
        # "stop generating" runs no function of its own; it only cancels events.
        stop_generating_button.click(cancels=[reply_button_click_event, submit_event])

    b.launch()