File size: 4,522 Bytes
4d5a5ba e49d5ca f878b1b e49d5ca 2ce0b48 f878b1b 2ce0b48 4d5a5ba e49d5ca 2ce0b48 4d5a5ba 2ce0b48 4d5a5ba 2ce0b48 4d5a5ba 2ce0b48 4d5a5ba 178fcdd 564da1a 2ce0b48 4d5a5ba 2ce0b48 4d5a5ba 178fcdd 4d5a5ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
from smolagents import HfApiModel, Tool, CodeAgent
from huggingface_hub import HfApi
import os
import sys
import json
# Make the project-local ./lib directory importable so that the retrieval
# helper module below can be resolved.
if './lib' not in sys.path:
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db

# `HfApi` objects expose no `login(token=...)` method; token-based login is
# the module-level `huggingface_hub.login` function, so import it here.
from huggingface_hub import login

# Access token is read from the environment (variable name: HF_TOKEN_all).
hf_token = os.getenv("HF_TOKEN_all")

# Keep the `api` client available under its original name, authenticated
# with the same token (the token may be None, in which case the client
# falls back to huggingface_hub's default token resolution).
api = HfApi(token=hf_token)

if hf_token:
    # Register the token with huggingface_hub so that clients created later
    # in this process can presumably pick it up -- TODO confirm this is what
    # the model wrapper relies on.
    login(token=hf_token)
else:
    # Calling login() without a token would try to prompt interactively,
    # which cannot work in a headless deployment; report and carry on.
    print(
        "HF_TOKEN_all is not set; skipping Hugging Face login.",
        file=sys.stderr,
    )
############################################################################################
################################### TOOLS ##################################################
############################################################################################
# Private marker meaning "key not present"; distinct from every real value,
# including None and the public "Indicator not found" message.
_KEY_NOT_FOUND = object()


def _search_key(data, target_key):
    """Depth-first search of nested dicts/lists for ``target_key``.

    Returns the value stored under the first matching key (in insertion
    order, descending into each value before moving on to the next key),
    or ``_KEY_NOT_FOUND`` when no dict in the structure has that key.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            found = _search_key(value, target_key)
            if found is not _KEY_NOT_FOUND:
                return found
    elif isinstance(data, list):
        # JSON documents may nest objects inside arrays; look there too.
        for item in data:
            found = _search_key(item, target_key)
            if found is not _KEY_NOT_FOUND:
                return found
    return _KEY_NOT_FOUND


def find_key(data, target_key):
    """Return the value stored under ``target_key`` anywhere inside ``data``.

    Args:
        data: A nested structure of dicts (and lists); any other type is
            treated as a leaf with no keys.
        target_key: The key to look for.

    Returns:
        The value of the first matching key found by a depth-first search,
        or the string ``"Indicator not found"`` when the key is absent.

    Note:
        The "not found" message is produced only at the top level. Inside
        the recursion a private sentinel is used instead, so a miss in one
        subtree no longer aborts the search of the remaining keys.
    """
    found = _search_key(data, target_key)
    if found is _KEY_NOT_FOUND:
        return "Indicator not found"
    return found
############################################################################################
class Chroma_retrieverTool(Tool):
    """Agent tool that returns the knowledge-base texts closest to a query.

    The lookup itself is delegated to ``retrieve_info_from_db``; this class
    only declares the tool metadata and formats the retrieved documents
    into a single string.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Run the retrieval for ``query`` and format the matching texts.

        Args:
            query: Text to search for.

        Returns:
            A string starting with a "Retrieval texts" header followed by
            one "===== Text i =====" section per retrieved document, each
            section on its own lines.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        assert isinstance(query, str), "The request needs to be a string."
        query_results = retrieve_info_from_db(query)
        # Only the first entry of 'documents' is used -- presumably the
        # result list for the single query submitted; verify against
        # retrieve_info_from_db.
        documents = query_results['documents'][0]
        sections = [
            f"===== Text {index} =====\n{text}"
            for index, text in enumerate(documents)
        ]
        # Separate sections with a newline so the next header does not run
        # into the last line of the previous document.
        return "\nRetrieval texts : \n" + "\n".join(sections)
############################################################################################
class ESRS_info_tool(Tool):
    """Agent tool that looks up an indicator in the ESRS dictionary file.

    The dictionary is read from ``./data/dico_esrs.json`` on every call and
    searched with ``find_key``.
    """

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the dictionary entry stored under ``indicator``.

        Args:
            indicator: Key to look up in the ESRS dictionary.

        Returns:
            The entry as a string. Entries that are not plain strings
            (for example nested objects) are serialised to JSON text so
            the declared "string" output type always holds.

        Raises:
            AssertionError: If ``indicator`` is not a string.
            OSError: If the dictionary file cannot be opened.
        """
        assert isinstance(indicator, str), "The request needs to be a string."
        # Read as UTF-8 explicitly rather than relying on the platform's
        # default locale encoding.
        with open('./data/dico_esrs.json', encoding="utf-8") as json_data:
            dico_esrs = json.load(json_data)
        result = find_key(dico_esrs, indicator)
        if isinstance(result, str):
            return result
        # Keep non-ASCII characters readable in the serialised output.
        return json.dumps(result, ensure_ascii=False)
############################################################################################
############################################################################################
############################################################################################
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p,):
    """Chat callback: forward the user's message to the agent.

    A fixed instruction preamble is prepended to ``message`` and the
    combined text is handed to the module-level ``agent``; whatever the
    agent returns is yielded once.

    ``history``, ``system_message``, ``max_tokens``, ``temperature`` and
    ``top_p`` are accepted to match the chat interface's call signature
    but are not read by this function.
    """
    preamble = (
        "You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.\n"
        "User's question : "
    )
    yield agent.run(preamble + message)
############################################################################################
# Language model used by the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Tool instances made available to the agent.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Code-writing agent wired to the model and both tools; it is limited to
# 10 steps and may import only the listed extra modules.
agent = CodeAgent(
    model=model,
    tools=[get_ESRS_info_tool, retriever_tool],
    additional_authorized_imports=["pandas", "matplotlib", "datetime"],
    max_steps=10,
    max_print_outputs_length=16000,
)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI: `respond` handles each message; the extra widgets below are
# passed to it as additional arguments, in this order.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are a friendly Chatbot.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
        ),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|