|
from pathlib import Path |
|
|
|
import streamlit as st |
|
|
|
from langchain import SQLDatabase |
|
from langchain.agents import AgentType |
|
from langchain.agents import initialize_agent, Tool |
|
from langchain.callbacks import StreamlitCallbackHandler |
|
from langchain.chains import LLMMathChain |
|
from langchain.llms import OpenAI |
|
from langchain.utilities import DuckDuckGoSearchAPIWrapper |
|
from langchain.llms import OpenAI |
|
from langchain.sql_database import SQLDatabase |
|
from langchain_experimental.sql import SQLDatabaseChain |
|
from langchain.chat_models import ChatOpenAI |
|
|
|
|
|
from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks |
|
from streamlit_agent.clear_results import with_clear_container |
|
|
|
from chat2plot import chat2plot |
|
from chat2plot.chat2plot import Plot |
|
|
|
import pandas as pd |
|
import sqlite3 |
|
import os |
|
|
|
# The key comes from the environment; when absent the app falls back to
# replayable sample questions only (see enable_custom further down).
user_openai_api_key = os.getenv("OPENAI_API_KEY")

# Absolute path to the bundled SQLite database, resolved next to this file.
DB_PATH = (Path(__file__).parent / "sitios2.sqlite").absolute()

# Canned questions; both map to the same recorded session pickle.
SAVED_SESSIONS = dict.fromkeys(
    [
        "how many points are in field_id = 29?",
        "what is the proportion of points in point_type?",
    ],
    "alanis.pickle",
)
|
|
|
# Page configuration must be the first Streamlit call in the script.
st.set_page_config(
    page_title="MRKL", page_icon="🦜", layout="wide", initial_sidebar_state="collapsed"
)

# Streamlit "magic": a bare string literal at module level is rendered as
# markdown — this is the page heading, not dead code.
"# Points and Samples"
|
|
|
|
|
|
|
|
|
|
|
|
|
# Free-text questions are enabled only when a real API key is present;
# otherwise a placeholder key is used and the UI restricts the user to
# the canned sample questions.
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
|
|
|
|
|
# Load the points_and_samples table once at startup; `df` feeds both the
# on-page table preview and the chat2plot session.
# try/finally guarantees the connection is closed even if the query raises
# (sqlite3's `with` manages transactions, not connection lifetime, so an
# explicit close is still required).
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT ogc_fid, field_id, point_id, sample_id, label, sample_state, "
        "plot_type, depth_range_shallow_m, depth_range_deep_m, sample_timestamp "
        "FROM points_and_samples",
        conn,
    )
finally:
    conn.close()
|
|
|
|
|
|
|
# Base LLM shared by the agent, the calculator chain and the SQL chain.
# streaming=True lets StreamlitCallbackHandler render tokens as they arrive.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
# Web search backend (DuckDuckGo needs no API key).
search = DuckDuckGoSearchAPIWrapper()
# LLM-backed calculator.
llm_math_chain = LLMMathChain.from_llm(llm)
# Natural-language-to-SQL chain over the same SQLite file loaded into `df`.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)
# chat2plot session bound to the dataframe, with its own chat model.
# NOTE(review): ChatOpenAI here is not given openai_api_key, so it presumably
# reads OPENAI_API_KEY from the environment — confirm it still works in the
# enable_custom=False path where only the placeholder key exists.
c2p = chat2plot(df,chat=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125"))
|
# Tool belt for the MRKL agent. The `description` strings are what the LLM
# reads to choose a tool, so they must say when each tool applies.
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    ),
    Tool(
        name="sitios piloto DB",
        func=db_chain.run,
        description="useful for when you need to answer questions about sitios piloto. Input should be in the form of a question containing full context",
    ),
    Tool(
        name="chat2plot",
        # NOTE(review): `c2p` is the chat2plot session object itself, not a
        # bound method like the other tools — Tool.func expects a callable;
        # confirm the session is callable (vs. needing e.g. c2p.query).
        func=c2p,
        description="useful for when you need to create a plot from a table",
        # return_direct=True: the tool's output is returned to the caller as
        # the agent's final answer, which is how a Plot object reaches the
        # isinstance(answer, Plot) rendering branch below.
        return_direct=True,
    ),
]
|
|
|
|
|
# ZERO_SHOT_REACT_DESCRIPTION: tool selection is driven purely by the tool
# descriptions above (no memory). verbose=True logs the reasoning trace.
mrkl = initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)

# Show the full table so users can see which columns they can ask about.
st.write(df)
|
|
|
|
|
# Question-entry form. Guarantees that `user_input` and `submit_clicked` are
# both defined for the rendering code below, whatever enable_custom is:
#   - no API key: only the canned sample questions can be chosen;
#   - API key present: a free-text box is shown, falling back to `prefilled`.
# (Previously `prefilled` existed only in the no-key branch and `user_input`
# only in the keyed branch, so each path could raise NameError; the unused
# `mrkl_input` and `question` locals are gone.)
with st.form(key="form"):
    prefilled = ""
    if not enable_custom:
        "Ask one of the sample questions, or enter your API Key in the sidebar to ask your own custom questions."
        prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""

    user_input = ""
    if enable_custom:
        user_input = st.text_input("Or, ask your own question")
    if not user_input:
        user_input = prefilled

    submit_clicked = st.form_submit_button("Submit Question")
|
|
|
# Placeholder slot in the page that the run's output is rendered into.
output_container = st.empty()
# with_clear_container (project helper) reports whether the form was
# submitted, clearing the previous run's results first.
if with_clear_container(submit_clicked):
    with output_container.container():
        # NOTE(review): calling .text() on the st.empty() placeholder replaces
        # the container opened on the previous line, so the "user"/"assistant"
        # labels overwrite each other — confirm the intended layout (the
        # upstream example uses st.chat_message-style blocks here).
        output_container.text("user")
        st.write(user_input)

    with output_container.container():
        output_container.text("assistant")

        # Target element for the streamed agent trace.
        answer_container = output_container.text("assistant")
        # NOTE(review): StreamlitCallbackHandler expects a container to draw
        # into; here it receives the element returned by .text() — verify this
        # renders as intended.
        st_callback = StreamlitCallbackHandler(answer_container)

        # Run the agent; the callback streams tool calls and thoughts live.
        # NOTE(review): SAVED_SESSIONS/playback_callbacks are imported but the
        # recorded sessions are never replayed here — every question, sample
        # or custom, triggers a live run against the placeholder key.
        answer = mrkl.run(user_input, callbacks=[st_callback])
        print(type(answer))  # debug leftover — consider removing
        # The chat2plot tool has return_direct=True, so the agent's answer can
        # be a Plot object; render its figure, otherwise write the text.
        if isinstance(answer, Plot):
            result = answer
            st.plotly_chart(result.figure)
        else:
            answer_container.write(answer)
|
|