ofermend committed on
Commit
941e6a0
·
1 Parent(s): c9767eb
Files changed (3) hide show
  1. agent.py +93 -82
  2. requirements.txt +1 -1
  3. st_app.py +7 -0
agent.py CHANGED
@@ -3,11 +3,14 @@ import pandas as pd
3
  import requests
4
  from functools import lru_cache
5
  from pydantic import Field, BaseModel
 
6
 
7
  from omegaconf import OmegaConf
8
 
9
  from vectara_agentic.agent import Agent
10
  from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
 
 
11
 
12
  from dotenv import load_dotenv
13
  load_dotenv(override=True)
@@ -34,7 +37,7 @@ years = range(2015, 2025)
34
  initial_prompt = "How can I help you today?"
35
 
36
  # Tool to get the income statement for a given company and year using the FMP API
37
- @lru_cache(maxsize=128)
38
  def fmp_income_statement(
39
  ticker: str = Field(description="the ticker symbol of the company.", examples=["AAPL", "GOOG", "AMZN"]),
40
  year: int = Field(description="the year for which to get the income statement.", examples=[2020, 2021, 2022]),
@@ -49,6 +52,8 @@ def fmp_income_statement(
49
  A dictionary with the income statement data.
50
  All data is in USD, but you can convert it to more compact form like K, M, B.
51
  """
 
 
52
  fmp_api_key = os.environ.get("FMP_API_KEY", None)
53
  if fmp_api_key is None:
54
  return "FMP_API_KEY environment variable not set. This tool does not work."
@@ -65,97 +70,99 @@ def fmp_income_statement(
65
  ]
66
  values_dict = income_statement_specific_year.to_dict(orient="records")[0]
67
  return f"Financial results: {', '.join([f'{key}={value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
68
-
69
  return f"FMP API returned error {response.status_code}. This tool does not work."
70
 
 
 
 
 
 
 
 
71
 
72
- def create_assistant_tools(cfg):
73
-
74
- def get_company_info() -> list[str]:
75
- """
76
- Returns a dictionary of companies you can query about. Always check this before using any other tool.
77
- The output is a dictionary of valid ticker symbols mapped to company names.
78
- You can use this to identify the companies you can query about, and their ticker information.
79
- """
80
- return tickers
81
-
82
- def get_valid_years() -> list[str]:
83
- """
84
- Returns a list of the years for which financial reports are available.
85
- Always check this before using any other tool.
86
- """
87
- return years
88
-
89
- class QueryTranscriptsArgs(BaseModel):
90
- query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
91
- year: int | str = Field(
92
- default=None,
93
- description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
94
- examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  )
96
- ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
97
-
98
- vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
99
- vectara_corpus_key=cfg.corpus_key)
100
- summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
101
- ask_transcripts = vec_factory.create_rag_tool(
102
- tool_name = "ask_transcripts",
103
- tool_description = """
104
- Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
105
- You can ask this tool any question about the company including risks, opportunities, financial performance, competitors and more.
106
- """,
107
- tool_args_schema = QueryTranscriptsArgs,
108
- reranker = "multilingual_reranker_v1", rerank_k = 100,
109
- n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
110
- summary_num_results = 10,
111
- vectara_summarizer = summarizer,
112
- include_citations = True,
113
- verbose=False,
114
- )
115
 
116
- class SearchTranscriptsArgs(BaseModel):
117
- query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
118
- top_k: int = Field(..., description="The number of results to return.")
119
- year: int | str = Field(
120
- default=None,
121
- description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
122
- examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
- ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
125
- search_transcripts = vec_factory.create_search_tool(
126
- tool_name = "search_transcripts",
127
- tool_description = """
128
- Given a company name and year, and a user query, retrieves the most relevant text from analyst call transcripts about the company related to the user query.
129
- """,
130
- tool_args_schema = SearchTranscriptsArgs,
131
- reranker = "multilingual_reranker_v1", rerank_k = 100,
132
- lambda_val = 0.005,
133
- verbose=False
134
- )
135
 
136
- tools_factory = ToolsFactory()
137
- return (
138
- [tools_factory.create_tool(tool) for tool in
139
- [
140
- get_company_info,
141
- get_valid_years,
142
- fmp_income_statement,
143
- ]
144
- ] +
145
- tools_factory.financial_tools() +
146
- [ask_transcripts, search_transcripts]
147
- )
148
 
149
  def initialize_agent(_cfg, agent_progress_callback=None):
150
  financial_bot_instructions = """
151
  - You are a helpful financial assistant, with expertise in financial reporting, in conversation with a user.
152
- - Always use the 'income_statement' tool to obtain accurate financial data like revenues, expenses, net income, and other financial metrics
153
- for a specific company, for any the year 2020 or later.
154
- - Use the 'fmp_income_statement' tool (with the company ticker and year) to obtain financial data for any year before 2020,
155
- - Use the 'fmp_income_statement` tool (with the company ticker and year) to obtain financial data for any year on or after 2020, when the 'income_statement'
156
- did not return any data useful to respond to the user query.
157
  - Always check the 'get_company_info' and 'get_valid_years' tools to validate company and year are valid.
158
- - Use the ask_transcripts tool to answer most questions about the company's financial performance, risks, opportunities, strategy, competitors, and more.
159
  - Respond in a compact format by using appropriate units of measure (e.g., K for thousands, M for millions, B for billions).
160
  Do not report the same number twice (e.g. $100K and 100,000 USD).
161
  - Do not include URLs unless they are provided in the output of a tool you use.
@@ -165,17 +172,21 @@ def initialize_agent(_cfg, agent_progress_callback=None):
165
  def query_logging(query: str, response: str):
166
  print(f"Logging query={query}, response={response}")
167
 
 
 
168
  agent = Agent(
169
- tools=create_assistant_tools(_cfg),
170
  topic="Financial data, annual reports and 10-K filings",
171
  custom_instructions=financial_bot_instructions,
172
  agent_progress_callback=agent_progress_callback,
173
  query_logging_callback=query_logging,
 
 
174
  )
 
175
  agent.report()
176
  return agent
177
 
178
-
179
  def get_agent_config() -> OmegaConf:
180
  companies = ", ".join(tickers.values())
181
  cfg = OmegaConf.create({
 
3
  import requests
4
  from functools import lru_cache
5
  from pydantic import Field, BaseModel
6
+ from typing import Any, Optional
7
 
8
  from omegaconf import OmegaConf
9
 
10
  from vectara_agentic.agent import Agent
11
  from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
12
+ from vectara_agentic.agent_config import AgentConfig
13
+ from vectara_agentic.sub_query_workflow import SubQuestionQueryWorkflow
14
 
15
  from dotenv import load_dotenv
16
  load_dotenv(override=True)
 
37
  initial_prompt = "How can I help you today?"
38
 
39
  # Tool to get the income statement for a given company and year using the FMP API
40
+ @lru_cache(maxsize=256)
41
  def fmp_income_statement(
42
  ticker: str = Field(description="the ticker symbol of the company.", examples=["AAPL", "GOOG", "AMZN"]),
43
  year: int = Field(description="the year for which to get the income statement.", examples=[2020, 2021, 2022]),
 
52
  A dictionary with the income statement data.
53
  All data is in USD, but you can convert it to more compact form like K, M, B.
54
  """
55
+ if ticker not in tickers or year not in years:
56
+ return "Invalid ticker or year. Please call this tool with a valid company ticker and year."
57
  fmp_api_key = os.environ.get("FMP_API_KEY", None)
58
  if fmp_api_key is None:
59
  return "FMP_API_KEY environment variable not set. This tool does not work."
 
70
  ]
71
  values_dict = income_statement_specific_year.to_dict(orient="records")[0]
72
  return f"Financial results: {', '.join([f'{key}={value}' for key, value in values_dict.items() if key not in ['date', 'cik', 'link', 'finalLink']])}"
73
+
74
  return f"FMP API returned error {response.status_code}. This tool does not work."
75
 
76
# The return annotation now matches what is actually returned: the module-level
# `tickers` mapping (ticker symbol -> company name), as the docstring states.
def get_company_info() -> dict[str, str]:
    """
    Returns a dictionary of companies you can query about. Always check this before using any other tool.
    The output is a dictionary of valid ticker symbols mapped to company names.
    You can use this to identify the companies you can query about, and their ticker information.
    """
    return tickers
83
 
84
# `years` is a module-level range of ints; it is materialized into a list so the
# returned value agrees with the annotation and renders as explicit years
# rather than as the repr of a range object.
def get_valid_years() -> list[int]:
    """
    Returns a list of the years for which financial reports are available.
    Always check this before using any other tool.
    """
    return list(years)
90
+
91
+
92
class AgentTools:
    """
    Builds the list of tools handed to the financial assistant agent.

    The factories are created once in ``__init__`` and reused by ``get_tools``.
    ``agent_config`` is stored on the instance but is not read by the code
    in this class.
    """

    def __init__(self, _cfg, agent_config):
        """
        Store the configuration and create the tool factories.

        Args:
            _cfg: configuration object exposing ``api_key`` and ``corpus_key``.
            agent_config: agent configuration, kept as ``self.agent_config``.
        """
        self.tools_factory = ToolsFactory()
        self.agent_config = agent_config
        self.cfg = _cfg
        self.vec_factory = VectaraToolFactory(
            vectara_api_key=_cfg.api_key,
            vectara_corpus_key=_cfg.corpus_key,
        )

    def get_tools(self):
        """
        Create and return every tool available to the agent.

        Returns:
            A list made of the wrapped function tools (``get_company_info``,
            ``get_valid_years``, ``fmp_income_statement``) followed by the
            ``ask_transcripts`` and ``search_transcripts`` Vectara tools.
        """
        # Argument schema for the RAG tool. No docstring is added on purpose:
        # the text of a model docstring could leak into the generated schema.
        class QueryTranscriptsArgs(BaseModel):
            query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
            year: int | str = Field(
                default=None,
                description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
                examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
            )
            ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")

        # Reuse the factory built in __init__ instead of constructing a second,
        # identical VectaraToolFactory from the same api/corpus keys.
        vec_factory = self.vec_factory
        summarizer = 'vectara-summary-table-md-query-ext-jan-2025-gpt-4o'
        ask_transcripts = vec_factory.create_rag_tool(
            tool_name = "ask_transcripts",
            tool_description = """
            Given a company name and year, responds to a user question about the company, based on analyst call transcripts about the company's financial reports for that year.
            You can ask this tool any question about the company including risks, opportunities, financial performance, competitors and more.
            """,
            tool_args_schema = QueryTranscriptsArgs,
            reranker = "multilingual_reranker_v1", rerank_k = 100, rerank_cutoff = 0.1,
            n_sentences_before = 2, n_sentences_after = 4, lambda_val = 0.005,
            summary_num_results = 15,
            vectara_summarizer = summarizer,
            include_citations = True,
            verbose=False,
        )

        # Argument schema for the search tool; same fields as above plus top_k.
        class SearchTranscriptsArgs(BaseModel):
            query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
            top_k: int = Field(..., description="The number of results to return.")
            year: int | str = Field(
                default=None,
                description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
                examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
            )
            ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")

        search_transcripts = vec_factory.create_search_tool(
            tool_name = "search_transcripts",
            tool_description = """
            Given a company name and year, and a user query, retrieves relevant documents about the company.
            """,
            tool_args_schema = SearchTranscriptsArgs,
            reranker = "multilingual_reranker_v1", rerank_k = 100,
            lambda_val = 0.005,
            verbose=False
        )

        # Reuse the ToolsFactory from __init__ to wrap the plain functions.
        tools_factory = self.tools_factory
        return (
            [tools_factory.create_tool(tool) for tool in
                [
                    get_company_info,
                    get_valid_years,
                    fmp_income_statement,
                ]
            ] +
            [ask_transcripts, search_transcripts]
        )
 
159
 
160
  def initialize_agent(_cfg, agent_progress_callback=None):
161
  financial_bot_instructions = """
162
  - You are a helpful financial assistant, with expertise in financial reporting, in conversation with a user.
163
+ - Use the 'fmp_income_statement' tool (with the company ticker and year) to obtain financial data.
 
 
 
 
164
  - Always check the 'get_company_info' and 'get_valid_years' tools to validate company and year are valid.
165
+ - Use the 'ask_transcripts' tool to answer most questions about the company's financial performance, risks, opportunities, strategy, competitors, and more.
166
  - Respond in a compact format by using appropriate units of measure (e.g., K for thousands, M for millions, B for billions).
167
  Do not report the same number twice (e.g. $100K and 100,000 USD).
168
  - Do not include URLs unless they are provided in the output of a tool you use.
 
172
  def query_logging(query: str, response: str):
173
  print(f"Logging query={query}, response={response}")
174
 
175
+ agent_config = AgentConfig()
176
+
177
  agent = Agent(
178
+ tools=AgentTools(_cfg, agent_config).get_tools(),
179
  topic="Financial data, annual reports and 10-K filings",
180
  custom_instructions=financial_bot_instructions,
181
  agent_progress_callback=agent_progress_callback,
182
  query_logging_callback=query_logging,
183
+ verbose=True,
184
+ #workflow_cls=SubQuestionQueryWorkflow,
185
  )
186
+
187
  agent.report()
188
  return agent
189
 
 
190
  def get_agent_config() -> OmegaConf:
191
  companies = ", ".join(tickers.values())
192
  cfg = OmegaConf.create({
requirements.txt CHANGED
@@ -6,4 +6,4 @@ streamlit_feedback==0.1.3
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
- vectara-agentic==0.2.1
 
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
+ vectara-agentic==0.2.5
st_app.py CHANGED
@@ -19,6 +19,10 @@ def format_log_msg(log_msg: str):
19
 
20
  def agent_progress_callback(status_type: AgentStatusType, msg: str):
21
  output = f'<span style="color:blue;">{status_type.value}</span>: {msg}'
 
 
 
 
22
  st.session_state.log_messages.append(output)
23
  if 'status' in st.session_state:
24
  latest_message = ''
@@ -140,6 +144,9 @@ async def launch_bot():
140
  with st.chat_message("assistant", avatar='🤖'):
141
  st.session_state.status = st.status('Processing...', expanded=False)
142
  response = st.session_state.agent.chat(st.session_state.prompt)
 
 
 
143
  res = escape_dollars_outside_latex(response.response)
144
 
145
  #response = await st.session_state.agent.achat(st.session_state.prompt)
 
19
 
20
  def agent_progress_callback(status_type: AgentStatusType, msg: str):
21
  output = f'<span style="color:blue;">{status_type.value}</span>: {msg}'
22
+ if "log_messages" not in st.session_state:
23
+ st.session_state.log_messages = [output]
24
+ else:
25
+ st.session_state.log_messages.append(output)
26
  st.session_state.log_messages.append(output)
27
  if 'status' in st.session_state:
28
  latest_message = ''
 
144
  with st.chat_message("assistant", avatar='🤖'):
145
  st.session_state.status = st.status('Processing...', expanded=False)
146
  response = st.session_state.agent.chat(st.session_state.prompt)
147
+
148
+ # from vectara_agentic.sub_query_workflow import SubQuestionQueryWorkflow
149
+ # response = await st.session_state.agent.run(inputs=SubQuestionQueryWorkflow.InputsModel(query=st.session_state.prompt))
150
  res = escape_dollars_outside_latex(response.response)
151
 
152
  #response = await st.session_state.agent.achat(st.session_state.prompt)