Spaces:
Sleeping
Sleeping
Issue fix
Browse files
app.py
CHANGED
@@ -55,16 +55,38 @@ available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
|
|
55 |
if "selected_model" not in st.session_state:
|
56 |
st.session_state["selected_model"] = available_models[0]
|
57 |
|
|
|
|
|
58 |
st.session_state["selected_model"] = st.sidebar.selectbox(
|
59 |
"Select LLM Model",
|
60 |
available_models,
|
61 |
index=available_models.index(st.session_state["selected_model"])
|
62 |
)
|
63 |
|
64 |
-
# Dynamically initialize LLM based on selection
|
65 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
classifier = load_models()
|
70 |
|
|
|
55 |
if "selected_model" not in st.session_state:
|
56 |
st.session_state["selected_model"] = available_models[0]
|
57 |
|
58 |
+
previous_model = st.session_state.get("selected_model", available_models[0])
|
59 |
+
|
60 |
st.session_state["selected_model"] = st.sidebar.selectbox(
|
61 |
"Select LLM Model",
|
62 |
available_models,
|
63 |
index=available_models.index(st.session_state["selected_model"])
|
64 |
)
|
65 |
|
|
|
66 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
67 |
+
|
68 |
+
if st.session_state["selected_model"] != previous_model:
|
69 |
+
if st.session_state.messages:
|
70 |
+
col1, col2 = st.columns([3,1])
|
71 |
+
with col1:
|
72 |
+
st.warning("Changing models will clear current conversation.")
|
73 |
+
with col2:
|
74 |
+
if st.button("Confirm Change"):
|
75 |
+
st.session_state.messages = []
|
76 |
+
st.session_state.current_image = None
|
77 |
+
st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
|
78 |
+
st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
|
79 |
+
st.rerun()
|
80 |
+
if st.button("Cancel"):
|
81 |
+
st.session_state["selected_model"] = previous_model
|
82 |
+
st.rerun()
|
83 |
+
else:
|
84 |
+
# No messages yet, just switch
|
85 |
+
st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
|
86 |
+
st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
|
87 |
+
|
88 |
+
llm = st.session_state.get("llm", initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY))
|
89 |
+
rag_chain = st.session_state.get("rag_chain", load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY))
|
90 |
|
91 |
classifier = load_models()
|
92 |
|