KeerthiVM commited on
Commit
8bad604
·
1 Parent(s): 9905e70
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -55,16 +55,38 @@ available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
55
  if "selected_model" not in st.session_state:
56
  st.session_state["selected_model"] = available_models[0]
57
 
 
 
58
  st.session_state["selected_model"] = st.sidebar.selectbox(
59
  "Select LLM Model",
60
  available_models,
61
  index=available_models.index(st.session_state["selected_model"])
62
  )
63
 
64
- # Dynamically initialize LLM based on selection
65
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
66
- llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
67
- rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  classifier = load_models()
70
 
 
55
  if "selected_model" not in st.session_state:
56
  st.session_state["selected_model"] = available_models[0]
57
 
58
+ previous_model = st.session_state.get("selected_model", available_models[0])
59
+
60
  st.session_state["selected_model"] = st.sidebar.selectbox(
61
  "Select LLM Model",
62
  available_models,
63
  index=available_models.index(st.session_state["selected_model"])
64
  )
65
 
 
66
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
67
+
68
+ if st.session_state["selected_model"] != previous_model:
69
+ if st.session_state.messages:
70
+ col1, col2 = st.columns([3,1])
71
+ with col1:
72
+ st.warning("Changing models will clear current conversation.")
73
+ with col2:
74
+ if st.button("Confirm Change"):
75
+ st.session_state.messages = []
76
+ st.session_state.current_image = None
77
+ st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
78
+ st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
79
+ st.rerun()
80
+ if st.button("Cancel"):
81
+ st.session_state["selected_model"] = previous_model
82
+ st.rerun()
83
+ else:
84
+ # No messages yet, just switch
85
+ st.session_state.llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
86
+ st.session_state.rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
87
+
88
+ llm = st.session_state.get("llm", initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY))
89
+ rag_chain = st.session_state.get("rag_chain", load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY))
90
 
91
  classifier = load_models()
92