File size: 4,757 Bytes
19b1388 2c7b8b6 7182eb1 19b1388 7182eb1 19b1388 7182eb1 2c7b8b6 19b1388 7182eb1 19b1388 7182eb1 19b1388 2c7b8b6 19b1388 2c7b8b6 19b1388 7182eb1 19b1388 7182eb1 19b1388 2c7b8b6 19b1388 7182eb1 19b1388 7182eb1 19b1388 29b611d 19b1388 7182eb1 2c7b8b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from pathlib import Path
import streamlit as st
from langchain import SQLDatabase
from langchain.agents import AgentType
from langchain.agents import initialize_agent, Tool
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import LLMMathChain, SQLDatabaseChain
from langchain.llms import OpenAI
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks
from streamlit_agent.clear_results import with_clear_container
from chat2plot import chat2plot
from chat2plot.chat2plot import Plot
import pandas as pd
import sqlite3
import os
# OpenAI key taken from the environment; None when the variable is not set.
user_openai_api_key = os.getenv("OPENAI_API_KEY")

# SQLite file (next to this script) holding the "sitiospilotojson" table.
DB_PATH = Path(__file__).parent.joinpath("sitios2.sqlite").absolute()

# Sample questions offered in the form's selectbox, each mapped to the name of
# a recorded session file. Every question currently points at the same file.
SAVED_SESSIONS = dict.fromkeys(
    (
        "Que provincias están representadas en los sitios pilotos?",
        "Cual es la superficie total de sitios piloto en Buenos Aires?",
        "Realiza un grafico de las areas por provincia de la tabla de sitios piloto",
    ),
    "alanis.pickle",
)
# Must be the first Streamlit call of the script.
st.set_page_config(
    page_title="MRKL",
    page_icon="🦜",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Streamlit "magic": a bare string literal is rendered as markdown.
"# Tabla Sitios Piloto"

# Free-form questions are only offered when a key was found in the
# environment (the former sidebar key input is disabled). Without one, a
# placeholder string is passed to the LLM instead of a real key.
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
# Read the sitios piloto table from the SQLite file into a DataFrame.
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT spoat, ecorregion, sag, provincia, supha, area, departamento "
        "FROM sitiospilotojson",
        conn,
    )
finally:
    # Close even if the query fails. Streamlit re-runs this script on every
    # interaction, so a leaked handle would be leaked again on each rerun.
    # (Using the connection as a context manager would NOT close it; that
    # only commits/rolls back.)
    conn.close()
# --- Tools setup -------------------------------------------------------------
# One streaming OpenAI completion model (temperature 0) drives the agent and
# every LLM-backed chain below.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)

search = DuckDuckGoSearchAPIWrapper()
llm_math_chain = LLMMathChain.from_llm(llm)

# Text-to-SQL chain over the same SQLite file the DataFrame was read from.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)

# chat2plot session bound to the loaded sitios piloto DataFrame.
c2p = chat2plot(df)

search_tool = Tool(
    name="Search",
    func=search.run,
    description=(
        "useful for when you need to answer questions about current events. "
        "You should ask targeted questions"
    ),
)
calculator_tool = Tool(
    name="Calculator",
    func=llm_math_chain.run,
    description="useful for when you need to answer questions about math",
)
sitios_db_tool = Tool(
    name="sitios piloto DB",
    func=db_chain.run,
    description=(
        "useful for when you need to answer questions about sitios piloto. "
        "Input should be in the form of a question containing full context"
    ),
)
plot_tool = Tool(
    name="chat2plot",
    func=c2p,
    description="useful for when you need to create a plot from a table",
    # The tool's raw result (a chat2plot Plot) becomes the agent's answer
    # as-is, so the page can detect it and render the figure.
    return_direct=True,
)
tools = [search_tool, calculator_tool, sitios_db_tool, plot_tool]

# --- Agent -------------------------------------------------------------------
# Zero-shot ReAct agent: chooses among the tools above using their
# descriptions only.
mrkl = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# Display the sitios piloto DataFrame (loaded from sitios2.sqlite earlier).
st.write(df)
# Question form: a sample question is always available, and a free-text
# question can be typed when an OpenAI key is configured.
with st.form(key="form"):
    if not enable_custom:
        # The key is read from the OPENAI_API_KEY environment variable; the
        # sidebar input this hint used to mention is disabled.
        "Ask one of the sample questions, or set the OPENAI_API_KEY environment variable to ask your own custom questions."
    # `or ""` guards against selectbox returning None.
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""
    # Initialise unconditionally. Without a key the text box is not shown,
    # and user_input would otherwise be undefined (NameError) below.
    user_input = ""
    if enable_custom:
        user_input = st.text_input("Or, ask your own question")
    # An empty custom question falls back to the selected sample question.
    if not user_input:
        user_input = prefilled
    submit_clicked = st.form_submit_button("Submit Question")
# Placeholder for the question/answer area, cleared by with_clear_container.
output_container = st.empty()
if with_clear_container(submit_clicked):
    # Turn the single-element placeholder into a multi-element container
    # once. Calling output_container.text(...) repeatedly would replace the
    # placeholder's content each time, discarding the question.
    output_container = output_container.container()
    output_container.text("user")
    output_container.write(user_input)

    output_container.text("assistant")
    # A real container, not a text element, so the callback handler can
    # stack each intermediate agent step instead of overwriting one element.
    answer_container = output_container.container()
    st_callback = StreamlitCallbackHandler(answer_container)

    # NOTE: replaying recorded sessions for the sample questions (via
    # playback_callbacks and the SAVED_SESSIONS files under runs/) is
    # disabled, so every question goes to the live agent.
    answer = mrkl.run(user_input, callbacks=[st_callback])

    # The chat2plot tool is return_direct, so a plot request yields a Plot
    # object rather than text. Render it inside the answer area so clearing
    # the results also removes the chart.
    if isinstance(answer, Plot):
        answer_container.plotly_chart(answer.figure)
    else:
        answer_container.write(answer)
|