KeerthiVM commited on
Commit
2860bc8
·
1 Parent(s): 8bad604
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -65,28 +65,36 @@ st.session_state["selected_model"] = st.sidebar.selectbox(
65
 
66
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
67
 
 
 
 
68
  if st.session_state["selected_model"] != previous_model:
69
  if st.session_state.messages:
70
- col1, col2 = st.columns([3,1])
71
- with col1:
72
  st.warning("Changing models will clear current conversation.")
73
- with col2:
74
- if st.button("Confirm Change"):
75
- st.session_state.messages = []
76
- st.session_state.current_image = None
77
- st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
78
- st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
79
- st.rerun()
80
- if st.button("Cancel"):
81
- st.session_state["selected_model"] = previous_model
82
- st.rerun()
 
83
  else:
84
- # No messages yet, just switch
85
- st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
86
- st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
87
-
88
- llm = st.session_state.get("llm", initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY))
89
- rag_chain = st.session_state.get("rag_chain", load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY))
 
 
 
 
90
 
91
  classifier = load_models()
92
 
 
65
 
66
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
67
 
68
+ if "model_change_confirmed" not in st.session_state:
69
+ st.session_state.model_change_confirmed = False
70
+
71
  if st.session_state["selected_model"] != previous_model:
72
  if st.session_state.messages:
73
+ st.session_state.model_change_confirmed = False # Reset confirmation state
74
+ with st.sidebar:
75
  st.warning("Changing models will clear current conversation.")
76
+ col1, col2 = st.columns(2)
77
+ with col1:
78
+ if st.button("Confirm Change", key="confirm_model_change"):
79
+ st.session_state.messages = []
80
+ st.session_state.current_image = None
81
+ st.session_state.model_change_confirmed = True
82
+ st.rerun()
83
+ with col2:
84
+ if st.button("Cancel", key="cancel_model_change"):
85
+ st.session_state["selected_model"] = previous_model
86
+ st.rerun()
87
  else:
88
+ st.session_state.model_change_confirmed = True
89
+
90
+ if "model_change_confirmed" not in st.session_state or st.session_state.model_change_confirmed:
91
+ llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
92
+ rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
93
+ st.session_state.llm = llm
94
+ st.session_state.rag_chain = rag_chain
95
+ else:
96
+ llm = st.session_state.get("llm", initialize_llm(previous_model, OPENAI_API_KEY))
97
+ rag_chain = st.session_state.get("rag_chain", load_rag_chain(previous_model, OPENAI_API_KEY))
98
 
99
  classifier = load_models()
100