Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,8 +8,8 @@ st.set_page_config(page_title="Chat", layout="wide")
|
|
8 |
# Sidebar: Model controls
|
9 |
st.sidebar.title("Model Controls")
|
10 |
model_options = {
|
11 |
-
"1": "karthikeyan-r/
|
12 |
-
"2": "karthikeyan-r/
|
13 |
}
|
14 |
|
15 |
model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
|
@@ -38,17 +38,7 @@ if load_model_button:
|
|
38 |
try:
|
39 |
# Load the selected model
|
40 |
if model_choice == model_options["1"]:
|
41 |
-
|
42 |
-
device = 0 if torch.cuda.is_available() else -1
|
43 |
-
st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
|
44 |
-
st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
|
45 |
-
st.session_state["qa_pipeline"] = pipeline(
|
46 |
-
"text2text-generation",
|
47 |
-
model=st.session_state["model"],
|
48 |
-
tokenizer=st.session_state["tokenizer"],
|
49 |
-
device=device
|
50 |
-
)
|
51 |
-
elif model_choice == model_options["2"]:
|
52 |
# Load the calculation model (calculation_model)
|
53 |
tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
|
54 |
model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
|
@@ -69,6 +59,18 @@ if load_model_button:
|
|
69 |
st.session_state["model"] = model
|
70 |
st.session_state["tokenizer"] = tokenizer
|
71 |
st.session_state["qa_pipeline"] = None # Calculation model doesn't use text2text pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
st.success("Model loaded successfully and ready!")
|
73 |
except Exception as e:
|
74 |
st.error(f"Error loading model: {e}")
|
@@ -120,7 +122,7 @@ if st.session_state["qa_pipeline"]:
|
|
120 |
st.error(f"Error generating response: {e}")
|
121 |
else:
|
122 |
# Handle user input for the calculation model (calculation_model)
|
123 |
-
if st.session_state["model"] and model_choice == model_options["
|
124 |
user_input = st.text_input(
|
125 |
"Enter your query for calculation:",
|
126 |
value=st.session_state["user_input"],
|
|
|
8 |
# Sidebar: Model controls
|
9 |
st.sidebar.title("Model Controls")
|
10 |
model_options = {
|
11 |
+
"1": "karthikeyan-r/calculation_model_11k",
|
12 |
+
"2": "karthikeyan-r/slm-custom-model_6k"
|
13 |
}
|
14 |
|
15 |
model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
|
|
|
38 |
try:
|
39 |
# Load the selected model
|
40 |
if model_choice == model_options["1"]:
|
41 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
# Load the calculation model (calculation_model)
|
43 |
tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
|
44 |
model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
|
|
|
59 |
st.session_state["model"] = model
|
60 |
st.session_state["tokenizer"] = tokenizer
|
61 |
st.session_state["qa_pipeline"] = None # Calculation model doesn't use text2text pipeline
|
62 |
+
|
63 |
+
elif model_choice == model_options["2"]:
|
64 |
+
# Load the T5 model for general QA (slm-custom-model_6k)
|
65 |
+
device = 0 if torch.cuda.is_available() else -1
|
66 |
+
st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
|
67 |
+
st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
|
68 |
+
st.session_state["qa_pipeline"] = pipeline(
|
69 |
+
"text2text-generation",
|
70 |
+
model=st.session_state["model"],
|
71 |
+
tokenizer=st.session_state["tokenizer"],
|
72 |
+
device=device
|
73 |
+
)
|
74 |
st.success("Model loaded successfully and ready!")
|
75 |
except Exception as e:
|
76 |
st.error(f"Error loading model: {e}")
|
|
|
122 |
st.error(f"Error generating response: {e}")
|
123 |
else:
|
124 |
# Handle user input for the calculation model (calculation_model)
|
125 |
+
if st.session_state["model"] and model_choice == model_options["1"]:
|
126 |
user_input = st.text_input(
|
127 |
"Enter your query for calculation:",
|
128 |
value=st.session_state["user_input"],
|