File size: 5,034 Bytes
19b1388 3142288 19b1388 3142288 91d5de3 3142288 19b1388 2c7b8b6 7182eb1 19b1388 7182eb1 19b1388 91d5de3 7c6eab8 19b1388 91d5de3 19b1388 7182eb1 19b1388 2c7b8b6 91d5de3 2c7b8b6 19b1388 91d5de3 19b1388 7182eb1 19b1388 7182eb1 19b1388 2c7b8b6 19b1388 7182eb1 19b1388 91d5de3 19b1388 7182eb1 19b1388 29b611d 19b1388 7182eb1 2c7b8b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from pathlib import Path
import streamlit as st
from langchain import SQLDatabase
from langchain.agents import AgentType
from langchain.agents import initialize_agent, Tool
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import LLMMathChain
from langchain.llms import OpenAI
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain.llms import OpenAI
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chat_models import ChatOpenAI
from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks
from streamlit_agent.clear_results import with_clear_container
from chat2plot import chat2plot
from chat2plot.chat2plot import Plot
import pandas as pd
import sqlite3
import os
# OpenAI credentials are taken from the environment; the sidebar password
# field that could collect them is currently disabled.
user_openai_api_key = os.environ.get("OPENAI_API_KEY")

# SQLite database stored next to this script; it feeds both the table view
# and the SQL tool.
DB_PATH = Path(__file__).parent.joinpath("sitios2.sqlite").absolute()

# Sample questions offered in the form's selectbox. Only the keys are read by
# the active code; the values name saved session files.
SAVED_SESSIONS = {
    "how many points are in field_id = 29?": "alanis.pickle",
    "what is the proportion of points in point_type?": "alanis.pickle",
}

# Page configuration is issued before any other st.* call.
st.set_page_config(
    page_title="MRKL",
    page_icon="🦜",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Streamlit "magic": a bare string literal is rendered as Markdown.
"# Points and Samples"

# Custom questions are only enabled when a non-empty key is available;
# otherwise a placeholder key is used for the LLM setup below.
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
# Load the points_and_samples table into a DataFrame. It is shown in the UI
# and handed to chat2plot. The connection is closed in `finally` so it is
# released even if the query raises (a sqlite3 connection's own context
# manager only commits/rolls back; it does not close).
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT ogc_fid, field_id, point_id, sample_id, label, sample_state, "
        "plot_type, depth_range_shallow_m, depth_range_deep_m, sample_timestamp "
        "FROM points_and_samples",
        conn,
    )
finally:
    conn.close()
# Tools setup
# Completion LLM shared by the agent and by the math and SQL chains.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
search = DuckDuckGoSearchAPIWrapper()
llm_math_chain = LLMMathChain.from_llm(llm)
# SQL chain over the same SQLite file the DataFrame was loaded from.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)
# chat2plot works on the in-memory DataFrame with a chat model. Pass it the
# same key as the completion LLM. Without it, ChatOpenAI falls back to reading
# OPENAI_API_KEY itself and ignores the "not_supplied" fallback used above.
c2p = chat2plot(
    df,
    chat=ChatOpenAI(
        temperature=0,
        model_name="gpt-3.5-turbo-0125",
        openai_api_key=openai_api_key,
    ),
)

# Tools exposed to the agent; it picks one based on each description.
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    ),
    Tool(
        name="sitios piloto DB",
        func=db_chain.run,
        description="useful for when you need to answer questions about sitios piloto. Input should be in the form of a question containing full context",
    ),
    Tool(
        name="chat2plot",
        func=c2p,
        description="useful for when you need to create a plot from a table",
        # Return the tool's raw result as the agent's answer so the UI can
        # detect a Plot object instead of receiving text.
        return_direct=True,
    ),
]
# Initialize agent: a zero-shot ReAct agent that chooses among `tools` based
# on their descriptions, with verbose output enabled.
mrkl = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# Show the points_and_samples DataFrame loaded earlier.
st.write(df)

# Question form: a sample-question selectbox, plus a free-text box when a real
# API key is available.
with st.form(key="form"):
    if not enable_custom:
        "Ask one of the sample questions, or set the OPENAI_API_KEY environment variable to ask your own custom questions."
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""
    # Default must be bound to `user_input` (not a separate name): without a
    # key the text box is skipped and the fallback below reads `user_input`.
    user_input = ""
    if enable_custom:
        user_input = st.text_input("Or, ask your own question")
    # An empty custom question falls back to the selected sample question.
    if not user_input:
        user_input = prefilled
    submit_clicked = st.form_submit_button("Submit Question")
# Placeholder for the question/answer transcript of the current run.
output_container = st.empty()
if with_clear_container(submit_clicked):
    # Turn the placeholder into ONE container and write everything into it.
    # st.empty() holds a single element, so calling output_container.text(...)
    # repeatedly replaced whatever was shown before. That discarded the
    # question and labels and gave the callback handler a text element
    # instead of a container.
    output = output_container.container()
    output.text("user")
    output.write(user_input)
    output.text("assistant")
    # Dedicated container for the agent's streamed steps and its final answer.
    answer_container = output.container()
    st_callback = StreamlitCallbackHandler(answer_container)
    # NOTE: replay of SAVED_SESSIONS via playback_callbacks is disabled, so
    # every question (sample or custom) runs the live agent against the API.
    answer = mrkl.run(user_input, callbacks=[st_callback])
    # The chat2plot tool is return_direct, so a plot request comes back as the
    # raw Plot object; anything else is written as-is.
    if isinstance(answer, Plot):
        answer_container.plotly_chart(answer.figure)
    else:
        answer_container.write(answer)
|