|
import json
import sys

import gradio as gr
from smolagents import CodeAgent, HfApiModel, Tool

# Make the project-local ./lib directory importable before loading from it.
if './lib' not in sys.path:
    sys.path.append('./lib')

from ingestion_chroma import retrieve_info_from_db
|
|
|
|
|
|
|
|
|
|
|
# Private marker for "key absent". It must be distinct from every possible
# stored value (including None and the public "Indicator not found" string).
_KEY_MISSING = object()


def _search_nested(data, target_key):
    """Depth-first search of nested dicts/lists for *target_key*.

    Returns the value stored under the first matching key, or the
    ``_KEY_MISSING`` sentinel when no dict in the structure has that key.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            found = _search_nested(value, target_key)
            if found is not _KEY_MISSING:
                return found
    elif isinstance(data, list):
        for item in data:
            found = _search_nested(item, target_key)
            if found is not _KEY_MISSING:
                return found
    return _KEY_MISSING


def find_key(data, target_key):
    """Return the value stored under *target_key* anywhere in *data*.

    The structure is walked depth-first in insertion order: for each key of a
    dict, the key itself is compared first, then its value is searched
    recursively before moving on to the next key. Lists are searched element
    by element.

    Args:
        data: Arbitrarily nested dicts/lists (any other type is a leaf).
        target_key: The dictionary key to look for.

    Returns:
        The value of the first matching key, or the string
        ``"Indicator not found"`` when the key does not occur.
    """
    # The recursion reports misses with a sentinel rather than the public
    # "Indicator not found" string: that string is a non-None value, so using
    # it internally made the first non-matching branch look like a hit and
    # ended the search before later keys were examined.
    found = _search_nested(data, target_key)
    return "Indicator not found" if found is _KEY_MISSING else found
|
|
|
|
|
|
|
class Chroma_retrieverTool(Tool):
    """Agent tool that fetches the stored texts closest to a query.

    The lookup itself is delegated to ``retrieve_info_from_db``; this class
    only formats what comes back into a single string.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Run *query* against the document database and format the hits.

        Args:
            query: Text to search for.

        Returns:
            A header line followed by one ``===== Text i =====`` section per
            retrieved text. Only the header is returned when nothing was
            retrieved.

        Raises:
            AssertionError: If *query* is not a string.
        """
        assert isinstance(query, str), "The request needs to be a string."

        query_results = retrieve_info_from_db(query)

        # NOTE(review): presumably 'documents' holds one inner list per query
        # and only a single query is sent -- confirm against
        # retrieve_info_from_db. An empty or missing outer list is treated as
        # "no hits" instead of raising on the [0] lookup.
        documents = query_results['documents']
        texts = documents[0] if documents else []

        # Each text is terminated with a newline so the next section header
        # starts on its own line rather than being glued to the previous text.
        sections = [
            f"===== Text {i} =====\n{text}\n" for i, text in enumerate(texts)
        ]
        return "\nRetrieval texts : \n" + "".join(sections)
|
|
|
|
|
|
|
class ESRS_info_tool(Tool):
    """Agent tool that looks up an indicator in ``./data/dico_esrs.json``."""

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the entry stored under *indicator* in the ESRS dictionary.

        The JSON file is loaded on every call and searched with ``find_key``.

        Args:
            indicator: Key to look for in the JSON structure.

        Returns:
            Whatever ``find_key`` returns for that key: the stored value, or
            its not-found message when the key is absent.

        Raises:
            AssertionError: If *indicator* is not a string.
        """
        assert isinstance(indicator, str), "The request needs to be a string."

        # Read explicitly as UTF-8 so non-ASCII characters are decoded the
        # same way on every platform instead of depending on the locale's
        # default encoding.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)

        return find_key(dico_esrs, indicator)
|
|
|
|
|
|
|
|
|
|
|
# Model that drives the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Tool instances handed to the agent.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Code-writing agent: at most 10 steps, and its generated code may import
# only the extra modules listed here.
agent = CodeAgent(
    model=model,
    tools=[get_ESRS_info_tool, retriever_tool],
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
    max_steps=10,
    max_print_outputs_length=16000,
)
|
|
|
|
|
def respond(message, history=None):
    """Answer a chat message by running the agent on it.

    Args:
        message: The user's question. When it is empty or whitespace-only,
            a built-in sample question is used instead.
        history: Earlier conversation turns. Accepted so the function can be
            called with ``(message, history)`` as a chat callback; it is not
            used.

    Yields:
        The agent's output for the prompt, as a single item.
    """
    system_prompt_added = (
        "You are an expert in environmental and corporate social responsibility. "
        "You must respond to requests using the query function in the document database.\n"
        "User's question : "
    )
    # Sample question that used to be sent unconditionally; kept only as the
    # fallback for an empty message.
    default_question = (
        "Find all informations about the ESRS E1–5: Energy consumption "
        "from fossil sources in Sartorius documents."
    )

    # Forward what the user actually asked: the prompt ends with
    # "User's question : ", so the question must follow it.
    question = message if message and message.strip() else default_question
    agent_output = agent.run(system_prompt_added + question)

    yield agent_output
|
|
|
|
|
""" |
|
For information on how to customize the ChatInterface, see the Gradio docs: https://www.gradio.app/docs/chatinterface
|
""" |
|
# Chat UI whose callback is `respond`.
demo = gr.ChatInterface(fn=respond)


if __name__ == "__main__":
    # Launch the interface only when this file is executed as a script.
    demo.launch()
|
|