karthikeyan-r committed on
Commit
5f762bd
·
verified ·
1 Parent(s): c2d1e89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -8,8 +8,8 @@ st.set_page_config(page_title="Chat", layout="wide")
8
  # Sidebar: Model controls
9
  st.sidebar.title("Model Controls")
10
  model_options = {
11
- "1": "karthikeyan-r/slm-custom-model_6k",
12
- "2": "karthikeyan-r/calculation_model"
13
  }
14
 
15
  model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
@@ -38,17 +38,7 @@ if load_model_button:
38
  try:
39
  # Load the selected model
40
  if model_choice == model_options["1"]:
41
- # Load the T5 model for general QA (slm-custom-model_6k)
42
- device = 0 if torch.cuda.is_available() else -1
43
- st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
44
- st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
45
- st.session_state["qa_pipeline"] = pipeline(
46
- "text2text-generation",
47
- model=st.session_state["model"],
48
- tokenizer=st.session_state["tokenizer"],
49
- device=device
50
- )
51
- elif model_choice == model_options["2"]:
52
  # Load the calculation model (calculation_model)
53
  tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
54
  model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
@@ -69,6 +59,18 @@ if load_model_button:
69
  st.session_state["model"] = model
70
  st.session_state["tokenizer"] = tokenizer
71
  st.session_state["qa_pipeline"] = None # Calculation model doesn't use text2text pipeline
 
 
 
 
 
 
 
 
 
 
 
 
72
  st.success("Model loaded successfully and ready!")
73
  except Exception as e:
74
  st.error(f"Error loading model: {e}")
@@ -120,7 +122,7 @@ if st.session_state["qa_pipeline"]:
120
  st.error(f"Error generating response: {e}")
121
  else:
122
  # Handle user input for the calculation model (calculation_model)
123
- if st.session_state["model"] and model_choice == model_options["2"]:
124
  user_input = st.text_input(
125
  "Enter your query for calculation:",
126
  value=st.session_state["user_input"],
 
8
  # Sidebar: Model controls
9
  st.sidebar.title("Model Controls")
10
  model_options = {
11
+ "1": "karthikeyan-r/calculation_model_11k",
12
+ "2": "karthikeyan-r/slm-custom-model_6k"
13
  }
14
 
15
  model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
 
38
  try:
39
  # Load the selected model
40
  if model_choice == model_options["1"]:
41
+
 
 
 
 
 
 
 
 
 
 
42
  # Load the calculation model (calculation_model)
43
  tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
44
  model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
 
59
  st.session_state["model"] = model
60
  st.session_state["tokenizer"] = tokenizer
61
  st.session_state["qa_pipeline"] = None # Calculation model doesn't use text2text pipeline
62
+
63
+ elif model_choice == model_options["2"]:
64
+ # Load the T5 model for general QA (slm-custom-model_6k)
65
+ device = 0 if torch.cuda.is_available() else -1
66
+ st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
67
+ st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
68
+ st.session_state["qa_pipeline"] = pipeline(
69
+ "text2text-generation",
70
+ model=st.session_state["model"],
71
+ tokenizer=st.session_state["tokenizer"],
72
+ device=device
73
+ )
74
  st.success("Model loaded successfully and ready!")
75
  except Exception as e:
76
  st.error(f"Error loading model: {e}")
 
122
  st.error(f"Error generating response: {e}")
123
  else:
124
  # Handle user input for the calculation model (calculation_model)
125
+ if st.session_state["model"] and model_choice == model_options["1"]:
126
  user_input = st.text_input(
127
  "Enter your query for calculation:",
128
  value=st.session_state["user_input"],