File size: 3,053 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from pathlib import Path

import pandas as pd
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.tools.retriever import create_retriever_tool
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_experimental.tools import PythonAstREPLTool

# Base directory for data files: one level above the folder holding this module.
MAIN_DIR = Path(__file__).parents[1]

# Cap how many rows and columns pandas prints; both limits are set in one call.
pd.set_option("display.max_rows", 20, "display.max_columns", 20)

# Embedding model handed to the FAISS loader below.
embedding_model = OpenAIEmbeddings()

# Load the locally stored FAISS index from <MAIN_DIR>/titanic_data.
vectorstore = FAISS.load_local(
    folder_path=MAIN_DIR / "titanic_data",
    embeddings=embedding_model,
)

# Wrap the index's retriever as a tool named `person_name_search`.
retriever_tool = create_retriever_tool(
    retriever=vectorstore.as_retriever(),
    name="person_name_search",
    description="Search for a person by name",
)


# System prompt for the agent. `{dhead}` is the only placeholder; it is filled
# via `TEMPLATE.format(dhead=...)` with the markdown rendering of `df.head()`.
# The tool names mentioned in the text (`person_name_search`, `python_repl`)
# have to match the names the tools are created with elsewhere in this module.
# NOTE(review): the prompt text contains typos ("do not have use",
# "exporatory"); left untouched here because it is runtime text sent to the
# model.
TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
It is important to understand the attributes of the dataframe before working with it. This is the result of running `df.head().to_markdown()`

<df>
{dhead}
</df>

You are not meant to use only these rows to answer questions - they are meant as a way of telling you about the shape and schema of the dataframe.
You also do not have use only the information here to answer questions - you can run intermediate queries to do exporatory data analysis to give you more information as needed.

You have a tool called `person_name_search` through which you can lookup a person by name and find the records corresponding to people with similar name as the query.
You should only really use this if your search term contains a persons name. Otherwise, try to solve it with code.

For example:

<question>How old is Jane?</question>
<logic>Use `person_name_search` since you can use the query `Jane`</logic>

<question>Who has id 320</question>
<logic>Use `python_repl` since even though the question is about a person, you don't know their name so you can't include it.</logic>
"""  # noqa: E501


# Argument schema passed to the `python_repl` tool as `args_schema`: a single
# string field, `query`, holding the code snippet to run.
# NOTE(review): documented with comments rather than a docstring on purpose —
# pydantic presumably surfaces a model docstring as the schema description,
# which would change what the tool advertises; confirm before converting.
class PythonInputs(BaseModel):
    query: str = Field(description="code snippet to run")


# Load the CSV once at import time; the same frame feeds both the prompt
# preview and the REPL tool's namespace.
df = pd.read_csv(MAIN_DIR / "titanic.csv")
template = TEMPLATE.format(dhead=df.head().to_markdown())

# Message order: system instructions, then the agent scratchpad, then the
# user's input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
        ("human", "{input}"),
    ]
)

# Python tool whose namespace exposes the dataframe under the name `df`.
repl = PythonAstREPLTool(
    name="python_repl",
    description="Runs code and returns the output of the final line",
    args_schema=PythonInputs,
    locals=dict(df=df),
)
tools = [repl, retriever_tool]

agent = OpenAIFunctionsAgent(
    llm=ChatOpenAI(model="gpt-4", temperature=0),
    tools=tools,
    prompt=prompt,
)

# Limit the run to five iterations, then pipe the executor's result through a
# step that keeps only its "output" entry.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    max_iterations=5,
    early_stopping_method="generate",
) | (lambda result: result["output"])

# Typing for playground inputs: the chain takes a single string field, `input`,
# matching the `{input}` placeholder used in the prompt.


class AgentInputs(BaseModel):
    input: str


# Rebind `agent_executor` to a copy that declares `AgentInputs` as its input type.
agent_executor = agent_executor.with_types(input_type=AgentInputs)