|
import gradio as gr |
|
from huggingface_hub import login |
|
from smolagents import HfApiModel, Tool, CodeAgent |
|
|
|
import os |
|
import sys |
|
import json |
|
|
|
# Make the project-local ./lib directory importable (idempotent append)
# so the ingestion helper below resolves when running from the repo root.
if './lib' not in sys.path :
    sys.path.append('./lib')
# NOTE(review): this import only works when the CWD is the project root —
# confirm the deployment always launches from there.
from ingestion_chroma import retrieve_info_from_db
|
|
|
|
|
|
|
|
|
|
|
def find_key(data, target_key):
    """Depth-first search of a nested structure for `target_key`.

    Args:
        data: Arbitrarily nested dicts/lists, typically parsed JSON.
        target_key: The dictionary key to look up.

    Returns:
        The value of the first occurrence of `target_key` (depth-first,
        insertion order), or the string "Indicator not found" when the
        key is absent anywhere in the structure.
    """
    result = _find_key(data, target_key)
    return "Indicator not found" if result is None else result


def _find_key(data, target_key):
    """Search helper: return the value for `target_key`, or None on a miss.

    Returning None for a miss (instead of the sentinel string) is what
    allows the search to continue through sibling branches — the previous
    version returned the truthy sentinel from nested calls, which aborted
    the scan after the first subtree.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            result = _find_key(value, target_key)
            if result is not None:
                return result
    elif isinstance(data, list):
        # Also descend into lists so keys inside list elements are found.
        for item in data:
            result = _find_key(item, target_key)
            if result is not None:
                return result
    return None
|
|
|
|
|
|
|
class Chroma_retrieverTool(Tool):
    """Agent tool: retrieve the most semantically similar texts from the Chroma knowledge base."""

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Run a similarity search for `query` and return the matches as one formatted string.

        Raises AssertionError if `query` is not a string.
        """
        assert isinstance(query, str), "The request needs to be a string."

        query_results = retrieve_info_from_db(query)
        # Chroma returns per-query lists; [0] is the result set for our single query.
        documents = query_results['documents'][0]
        str_result = "\nRetrieval texts : \n" + "".join(
            f"===== Text {i} =====\n" + doc for i, doc in enumerate(documents)
        )
        return str_result
|
|
|
|
|
|
|
class ESRS_info_tool(Tool):
    """Agent tool: look up the description of an ESRS indicator in a local JSON dictionary."""

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the description for `indicator`, or "Indicator not found" when absent.

        Raises AssertionError if `indicator` is not a string.
        """
        assert isinstance(indicator, str), "The request needs to be a string."

        # Explicit encoding so the JSON decodes identically on every platform
        # (the default encoding is locale-dependent).
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)

        # Recursive search over the nested ESRS dictionary.
        return find_key(dico_esrs, indicator)
|
|
|
|
|
|
|
|
|
|
|
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p,):
    """Generate a reply to `message` by running the retrieval agent.

    The remaining parameters (history, system_message, max_tokens,
    temperature, top_p) are part of the Gradio ChatInterface callback
    signature but are not forwarded to the agent.

    Yields:
        The agent's final answer as a single chunk.
    """
    prompt_prefix = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.

User's question : """

    yield agent.run(prompt_prefix + message)
|
|
|
|
|
# --- Model and agent wiring (runs at import time) ---

# Hugging Face auth token. NOTE(review): os.getenv returns None when the
# variable is unset — confirm HF_TOKEN_all is always defined in the
# deployment environment before login() is called.
hf_token = os.getenv("HF_TOKEN_all")
login(hf_token)

# Remote inference endpoint used by the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Tool instances exposed to the agent.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Code-writing agent limited to 10 reasoning steps; the generated code may
# additionally import pandas, matplotlib and datetime.
agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime']
)
|
|
|
|
|
""" |
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
""" |
|
demo = gr.ChatInterface( |
|
respond, |
|
additional_inputs=[ |
|
gr.Textbox(value="You are a friendly Chatbot.", label="System message"), |
|
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), |
|
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), |
|
gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.95, |
|
step=0.05, |
|
label="Top-p (nucleus sampling)", |
|
), |
|
], |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|