# streamlit_agent/mrkl_demo.py — MRKL-style agent demo (Streamlit UI over a
# LangChain zero-shot ReAct agent with search, math, SQL and plotting tools).
from pathlib import Path
import streamlit as st
from langchain import SQLDatabase
from langchain.agents import AgentType
from langchain.agents import initialize_agent, Tool
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import LLMMathChain
from langchain.llms import OpenAI
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain.llms import OpenAI
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chat_models import ChatOpenAI
from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks
from streamlit_agent.clear_results import with_clear_container
from chat2plot import chat2plot
from chat2plot.chat2plot import Plot
import pandas as pd
import sqlite3
import os
# The OpenAI key is read from the environment rather than typed in the sidebar.
user_openai_api_key = os.environ.get("OPENAI_API_KEY")

# Canned sample questions mapped to pre-recorded agent sessions.
SAVED_SESSIONS = {
    "how many points are in field_id = 29?": "alanis.pickle",
    "what is the proportion of points in point_type?": "alanis.pickle",
}

# SQLite database shipped next to this script.
# DB_PATH = (Path(__file__).parent / "Chinook.db").absolute()
DB_PATH = (Path(__file__).parent / "sitios2.sqlite").absolute()
st.set_page_config(
    page_title="MRKL",
    page_icon="🦜",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Streamlit "magic": a bare string literal is rendered as markdown.
"# Points and Samples"

# Credentials: a key found in the environment unlocks free-form questions;
# otherwise only the canned sample questions are offered and a placeholder
# key is passed downstream.
# user_openai_api_key = st.sidebar.text_input(
#     "OpenAI API Key", type="password", help="Set this to run your own custom questions."
# )
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
# Load the points_and_samples table into a dataframe for chat2plot and for
# display below. The try/finally guarantees the sqlite connection is closed
# even if the query fails (the original leaked it on error).
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT ogc_fid, field_id, point_id, sample_id, label, sample_state, plot_type, depth_range_shallow_m, depth_range_deep_m, sample_timestamp FROM points_and_samples",
        conn,
    )
finally:
    conn.close()
# Tools setup
# Shared completion LLM used by the agent, the math chain and the SQL chain.
# NOTE(review): openai_api_key may be the "not_supplied" placeholder when no
# env key was found — API calls would then fail with an auth error; confirm
# the intended behavior for keyless sessions.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
# Web search wrapper (no API key required).
search = DuckDuckGoSearchAPIWrapper()
# Math word-problem chain backed by the same LLM.
llm_math_chain = LLMMathChain.from_llm(llm)
# Natural-language-to-SQL chain over the same sitios2.sqlite file that was
# loaded into `df` above.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)
# chat2plot session bound to the dataframe; exposed as a plotting tool below.
# NOTE(review): `c2p` is the chat2plot object itself, passed directly as a
# Tool func — presumably it is callable with a question string; verify.
c2p = chat2plot(df,chat=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125"))
# Tool belt handed to the agent: web search, calculator, SQL Q&A and plotting.
tools = []
tools.append(
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    )
)
tools.append(
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    )
)
tools.append(
    Tool(
        name="sitios piloto DB",
        func=db_chain.run,
        description="useful for when you need to answer questions about sitios piloto. Input should be in the form of a question containing full context",
    )
)
# return_direct: the plot result is handed straight back to the caller
# instead of being fed through another agent reasoning step.
tools.append(
    Tool(
        name="chat2plot",
        func=c2p,
        description="useful for when you need to create a plot from a table",
        return_direct=True,
    )
)
# Initialize agent
# Zero-shot ReAct agent: picks among the four tools per question based on
# their descriptions, with verbose tracing enabled.
mrkl = initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
# Show the raw points_and_samples table (loaded from sitios2.sqlite above).
st.write(df)
# Question input form. Fix: `user_input` was previously assigned only inside
# the `enable_custom` branch, so submitting a sample question without an API
# key raised a NameError further down; it now always defaults to the canned
# selection. Dead locals `mrkl_input` and `question` were removed.
with st.form(key="form"):
    if not enable_custom:
        "Ask one of the sample questions, or enter your API Key in the sidebar to ask your own custom questions."
    # Canned question picker; `or ""` guards against a None selection.
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""
    user_input = prefilled
    if enable_custom:
        # A non-empty typed question overrides the canned one.
        typed_question = st.text_input("Or, ask your own question")
        if typed_question:
            user_input = typed_question
    submit_clicked = st.form_submit_button("Submit Question")
# Placeholder that hosts the question/answer exchange; replaced per submit.
output_container = st.empty()
# with_clear_container wipes previous results and reports whether to run.
if with_clear_container(submit_clicked):
    # Echo the user's question.
    with output_container.container():
        output_container.text("user")
        st.write(user_input)
    # output_container = output_container.container()
    # output_container.chat_message("user").write(user_input)
    # Run the agent and stream its reasoning into the assistant area.
    with output_container.container():
        output_container.text("assistant")
        # st.write(answer_container)
        # NOTE(review): the "assistant" label is written twice (here and on
        # the line above); answer_container is whatever .text() returns —
        # presumably a DeltaGenerator; confirm StreamlitCallbackHandler
        # accepts it as a parent container.
        answer_container = output_container.text("assistant")
        st_callback = StreamlitCallbackHandler(answer_container)
        # If we've saved this question, play it back instead of actually running LangChain
        # (so that we don't exhaust our API calls unnecessarily)
        # if user_input in SAVED_SESSIONS:
        #     session_name = SAVED_SESSIONS[user_input]
        #     session_path = Path(__file__).parent / "runs" / session_name
        #     print(f"Playing saved session: {session_path}")
        #     answer = playback_callbacks([st_callback], str(session_path), max_pause_time=2)
        # else:
        #     answer = mrkl.run(user_input, callbacks=[st_callback])
        answer = mrkl.run(user_input, callbacks=[st_callback])
        print(type(answer))
        # The chat2plot tool (return_direct=True) can hand back a Plot object;
        # render its figure with plotly, otherwise treat the answer as text.
        if isinstance(answer, Plot):
            result = answer
            st.plotly_chart(result.figure)
        else:
            answer_container.write(answer)