Added Ollama and model option.
Rewrote search optimization query prompt.
- messages.py +37 -36
- search_agent.py +26 -10
messages.py
CHANGED

The search-optimization prompt was rewritten in full; the new version of the file:

import json
from langchain.schema import SystemMessage, HumanMessage


def get_optimized_search_messages(query):
    messages = [
        SystemMessage(
            content="""
                You are a search query optimizer specialist.
                Provide a better search query for a web search engine to answer the given question, and end the query with '**'.
                Tips:
                Identify the key concepts in the question
                Remove filler words like "how to", "what is", "I want to"
                Remove style instructions (example: "in the style of", "engaging", "short", "long")
                Remove length instructions (example: essay, article, letter, blog, post, blogpost, etc.)
                Remove formatting instructions
                Keep it short, around 3-7 words total
                Put the most important keywords first
                Example:
                Question: How do I bake chocolate chip cookies from scratch?
                Search query: chocolate chip cookies recipe from scratch**
                Example:
                Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
                Search query: Marie Curie timeline**
                Example:
                Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
                Search query: geopolitics nato russia**
                Example:
                Question: Write an engaging LinkedIn post about Andrew Ng
                Search query: Andrew Ng**
                Example:
                Question: Write a short article about the solar system in the style of Carl Sagan
                Search query: solar system**
                Example:
                Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
                Search query: Kubernetes decision**
                Example:
                Question: biography of napoleon. include a table with the major events.
                Search query: napoleon biography events**
            """
        ),
        HumanMessage(
            content=f"""
                Provide a better search query for a web search engine to answer the given question. Provide only one search query and nothing else, and end the query with '**'.
                Question: {query}
                Search query:
            """
        ),
    ]
    return messages
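As a quick sanity check, a minimal sketch of what this function produces (assumes langchain is installed; the sample question is illustrative, not from the commit):

```python
from messages import get_optimized_search_messages

# The function returns a [SystemMessage, HumanMessage] pair ready for chat.invoke()
msgs = get_optimized_search_messages("Write an engaging blog post about vector databases")
for m in msgs:
    print(type(m).__name__, "->", m.content[:60].strip(), "...")
```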
search_agent.py
CHANGED

@@ -4,6 +4,7 @@ Usage:
     search_agent.py
         [--domain=domain]
         [--provider=provider]
+        [--model=model]
         [--temperature=temp]
         [--max_pages=num]
         [--output=text]
@@ -15,8 +16,9 @@ Options:
     --version                         Show version.
     -d domain --domain=domain         Limit search to a specific domain
     -t temp --temperature=temp        Set the temperature of the LLM [default: 0.0]
-    -p provider --provider=provider   Use a specific LLM (choices: bedrock,openai,groq) [default: openai]
-    -m
+    -p provider --provider=provider   Use a specific LLM (choices: bedrock,openai,groq,ollama) [default: openai]
+    -m model --model=model            Use a specific model
+    -n num --max_pages=num            Max number of pages to retrieve [default: 10]
     -o text --output=text             Output format (choices: text, markdown) [default: markdown]

"""
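docopt builds the arguments dict directly from this usage/options text. A standalone toy sketch of how the new flags parse (uses the pip-installable docopt package; this is not the script's actual docstring):

```python
"""Usage: demo.py [--provider=provider] [--model=model] SEARCH_QUERY

Options:
  -p provider --provider=provider  Use a specific LLM [default: openai]
  -m model --model=model           Use a specific model
"""
from docopt import docopt

# Parse a simulated command line instead of sys.argv
args = docopt(__doc__, argv=["--provider=ollama", "what is faiss?"])
print(args["--provider"])    # 'ollama'
print(args["--model"])       # None -> get_chat_llm falls back to a per-provider default
print(args["SEARCH_QUERY"])  # 'what is faiss?'
```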
@@ -35,6 +37,7 @@ from langchain.schema import SystemMessage, HumanMessage
 from langchain.callbacks import LangChainTracer
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
+from langchain_community.chat_models import ChatOllama
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores.faiss import FAISS
 from langchain_community.chat_models.bedrock import BedrockChat
@@ -47,19 +50,30 @@ from rich.rule import Rule
 from rich.markdown import Markdown


-def get_chat_llm(provider, temperature=0.0):
-    console.log(f"Using provider {provider} with temperature {temperature}")
+def get_chat_llm(provider, model, temperature=0.0):
     match provider:
         case 'bedrock':
+            if model is None:
+                model = "anthropic.claude-3-sonnet-20240229-v1:0"
             chat_llm = BedrockChat(
                 credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
-                model_id=
+                model_id=model,
                 model_kwargs={"temperature": temperature},
             )
         case 'openai':
-
+            if model is None:
+                model = "gpt-3.5-turbo"
+            chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
         case 'groq':
-
+            if model is None:
+                model = 'mixtral-8x7b-32768'
+            chat_llm = ChatGroq(model_name=model, temperature=temperature)
+        case 'ollama':
+            if model is None:
+                model = 'llama2'
+            chat_llm = ChatOllama(model=model, temperature=temperature)
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
+
+    console.log(f"Using {model} on {provider} with temperature {temperature}")
     return chat_llm
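A quick usage sketch of the new dispatch (assumes it runs inside search_agent.py, where console is defined, and that a local Ollama server with the llama2 model is available):

```python
# Passing model=None exercises the per-provider fallback ('llama2' for ollama)
chat = get_chat_llm("ollama", None, temperature=0.0)

# LangChain chat models accept a plain string as input
reply = chat.invoke("Reply with exactly one word: ready?")
print(reply.content)
```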
@@ -66,9 +80,10 @@

 def optimize_search_query(query):
     from messages import get_optimized_search_messages
     messages = get_optimized_search_messages(query)
     response = chat.invoke(messages, config={"callbacks": callbacks})
-
+    optimized_search_query = response.content
+    return optimized_search_query.strip('"').strip("**")


 def get_sources(query, max_pages=10, domain=None):
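Note that str.strip treats its argument as a set of characters rather than a literal suffix, so .strip("**") removes any run of leading or trailing asterisks and .strip('"') removes surrounding double quotes. A standalone illustration:

```python
raw = '"chocolate chip cookies recipe from scratch**"'
print(raw.strip('"').strip("**"))  # chocolate chip cookies recipe from scratch

# Only the ends are touched; interior asterisks survive
print("**a**b**".strip("**"))      # a**b
```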
@@ -219,13 +234,14 @@ if __name__ == '__main__':
     arguments = docopt(__doc__, version='Search Agent 0.1')

     provider = arguments["--provider"]
+    model = arguments["--model"]
     temperature = float(arguments["--temperature"])
     domain = arguments["--domain"]
     max_pages = arguments["--max_pages"]
     output = arguments["--output"]
     query = arguments["SEARCH_QUERY"]

-    chat = get_chat_llm(provider, temperature)
+    chat = get_chat_llm(provider, model, temperature)

     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimize_search_query = optimize_search_query(query)