import os
from typing import Annotated

import streamlit as st
from typing_extensions import TypedDict
from langchain_community.utilities import ArxivAPIWrapper, WikipediaAPIWrapper
from langchain_community.tools import ArxivQueryRun, WikipediaQueryRun
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

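
# Wikipedia and ArXiv search tools, each limited to one result and 300 characters of content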
arxiv_wrapper = ArxivAPIWrapper(top_k_results=1, doc_content_chars_max=300)
arxiv_tool = ArxivQueryRun(api_wrapper=arxiv_wrapper)

wiki_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=300)
wiki_tool = WikipediaQueryRun(api_wrapper=wiki_wrapper)

tools = [wiki_tool, arxiv_tool]

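
# Graph state: a message list that LangGraph appends to via add_messages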
class State(TypedDict):
    messages: Annotated[list, add_messages]

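
# Create the Groq LLM once and cache it across Streamlit reruns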
@st.cache_resource
def initialize_llm():
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        st.error("Please set the GROQ_API_KEY environment variable.")
        st.stop()
    return ChatGroq(groq_api_key=groq_api_key, model_name="Gemma2-9b-It")

llm = initialize_llm()
llm_with_tools = llm.bind_tools(tools=tools)

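
# Chatbot node: call the tool-aware LLM on the conversation so far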
def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

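
# Build and cache the graph: chatbot -> tools (when the LLM requests a tool call) -> chatbot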
@st.cache_resource
def build_graph():
    graph_builder = StateGraph(State)
    graph_builder.add_node("chatbot", chatbot)
    tool_node = ToolNode(tools=tools)
    graph_builder.add_node("tools", tool_node)
    graph_builder.add_conditional_edges("chatbot", tools_condition)
    graph_builder.add_edge("tools", "chatbot")
    graph_builder.add_edge(START, "chatbot")
    return graph_builder.compile()


graph = build_graph()

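
# Streamlit UI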
st.title("WIKXIV AI: Wikipedia and ArXiv Chatbot")

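
# Keep the displayed chat history in Streamlit session state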
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
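
# Replay the conversation so far on every rerun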
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

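
# Handle a new question from the chat input box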
if prompt := st.chat_input("What is your question?"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

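
    # Run the graph on the new prompt; with stream_mode="values" each event carries the full state,
    # so the last entry in event["messages"] is the newest message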
    events = graph.stream(
        {"messages": [("user", prompt)]},
        stream_mode="values",
    )

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        for event in events:
            message = event["messages"][-1]
            # Show only assistant text; skip the echoed user prompt and raw tool output
            if message.type == "ai" and message.content:
                full_response += message.content
                message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})

st.sidebar.title("About")
st.sidebar.info(
    "This Streamlit app, made by theaimart, uses LangGraph to give a chatbot "
    "access to Wikipedia and ArXiv search tools."
)