File size: 5,034 Bytes
19b1388
 
 
 
 
 
 
 
3142288
19b1388
 
3142288
 
 
91d5de3
3142288
19b1388
 
 
2c7b8b6
 
 
 
 
 
7182eb1
19b1388
7182eb1
 
 
 
19b1388
 
91d5de3
7c6eab8
19b1388
 
 
 
 
 
91d5de3
19b1388
 
7182eb1
 
 
19b1388
 
 
 
 
 
 
 
2c7b8b6
 
91d5de3
2c7b8b6
 
 
19b1388
 
 
 
 
 
91d5de3
19b1388
 
 
 
 
 
 
 
 
 
 
 
7182eb1
19b1388
7182eb1
19b1388
2c7b8b6
 
 
 
 
 
19b1388
 
 
 
 
 
 
7182eb1
 
 
 
 
 
19b1388
 
 
 
 
 
 
91d5de3
 
19b1388
7182eb1
19b1388
 
 
 
 
29b611d
 
 
 
 
 
 
 
 
 
19b1388
 
 
7182eb1
 
 
 
 
 
 
 
2c7b8b6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from pathlib import Path

import streamlit as st

from langchain import SQLDatabase
from langchain.agents import AgentType
from langchain.agents import initialize_agent, Tool
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import LLMMathChain
from langchain.llms import OpenAI
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain.llms import OpenAI
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chat_models import ChatOpenAI


from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks
from streamlit_agent.clear_results import with_clear_container

from chat2plot import chat2plot
from chat2plot.chat2plot import Plot

import pandas as pd
import sqlite3
import os

# The OpenAI key is taken from the environment, not from the UI.
user_openai_api_key = os.getenv("OPENAI_API_KEY")

# SQLite database stored next to this script (the Chinook sample DB was used
# here previously).
DB_PATH = Path(__file__).parent.joinpath("sitios2.sqlite").absolute()

# Canned questions offered in the "Sample questions" select box; each value is
# the filename of a recorded-session pickle.
# NOTE(review): both questions map to the same pickle — confirm intended.
SAVED_SESSIONS = {
    "how many points are in field_id = 29?": "alanis.pickle",
    "what is the proportion of points in point_type?": "alanis.pickle",
}

st.set_page_config(
    page_title="MRKL",
    page_icon="🦜",
    layout="wide",
    initial_sidebar_state="collapsed",
)

"# Points and Samples"

# Custom (free-text) questions are enabled only when an OpenAI key was found
# in the environment; without one, a placeholder key string is handed to the
# LLM wrappers instead. (A sidebar text input for the key used to exist here
# and has been disabled.)
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"


# Load the points_and_samples table into a DataFrame. It is displayed on the
# page and is the table chat2plot draws its charts from.
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT ogc_fid, field_id, point_id, sample_id, label, sample_state, "
        "plot_type, depth_range_shallow_m, depth_range_deep_m, sample_timestamp "
        "FROM points_and_samples",
        conn,
    )
finally:
    # Close the connection even if the query fails (e.g. a missing column).
    conn.close()


# Tools setup
# Completion LLM shared by the ReAct agent and the math / SQL chains.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
search = DuckDuckGoSearchAPIWrapper()
llm_math_chain = LLMMathChain.from_llm(llm)
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)
# chat2plot uses its own chat model. Pass the same key explicitly: without it,
# ChatOpenAI falls back to the OPENAI_API_KEY environment variable and fails at
# construction when that is unset, crashing the app at startup even though the
# other LLM wrappers are given a placeholder key.
c2p = chat2plot(
    df,
    chat=ChatOpenAI(
        temperature=0,
        model_name="gpt-3.5-turbo-0125",
        openai_api_key=openai_api_key,
    ),
)
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    ),
    Tool(
        name="sitios piloto DB",
        func=db_chain.run,
        description="useful for when you need to answer questions about sitios piloto. Input should be in the form of a question containing full context",
    ),
    # return_direct=True hands chat2plot's result back to the caller unchanged,
    # so the agent's answer can be a chart object rather than text.
    Tool(
        name="chat2plot",
        func=c2p,
        description="useful for when you need to create a plot from a table",
        return_direct=True,
    ),
]

# Initialize agent: a zero-shot ReAct agent that picks a tool by description.
mrkl = initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)

# Show the points_and_samples data that was loaded from DB_PATH.
st.write(df)


with st.form(key="form"):
    if not enable_custom:
        "Ask one of the sample questions, or set the OPENAI_API_KEY environment variable to ask your own custom questions."
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""

    # Default to "" so the fallback below also works when the free-text box is
    # not rendered (this used to raise NameError when no API key was set).
    user_input = ""
    if enable_custom:
        user_input = st.text_input("Or, ask your own question")
    # An empty free-text box falls back to the selected sample question.
    if not user_input:
        user_input = prefilled
    submit_clicked = st.form_submit_button("Submit Question")

# Placeholder for the question/answer area. with_clear_container (from
# streamlit_agent.clear_results) presumably returns True when a submitted
# question should be rendered on this run — confirm against that helper.
output_container = st.empty()
if with_clear_container(submit_clicked):
    # st.empty() holds a single element, and each call on it replaces the
    # previous one. Turn it into a multi-element container first so the
    # labels, the question, and the answer area all remain visible.
    output_container = output_container.container()
    output_container.text("user")
    output_container.write(user_input)

    output_container.text("assistant")
    # StreamlitCallbackHandler needs a container to render the agent's
    # intermediate steps into; the final answer is appended after them.
    answer_container = output_container.container()
    st_callback = StreamlitCallbackHandler(answer_container)

    # NOTE: replaying recorded SAVED_SESSIONS runs via playback_callbacks is
    # disabled, so every question (sample ones included) calls the live agent
    # and needs a valid OpenAI key.
    answer = mrkl.run(user_input, callbacks=[st_callback])

    if isinstance(answer, Plot):
        # The chat2plot tool returns its Plot directly. Plot.figure is
        # optional in chat2plot (None when no chart was produced), so fall
        # back to the model's explanation text in that case.
        if answer.figure is not None:
            answer_container.plotly_chart(answer.figure)
        else:
            answer_container.write(answer.explanation)
    else:
        answer_container.write(answer)