Commit fd97c8c
Parent(s): 2b27faa
add 3B model

Files changed:
- app.py +20 -9
- utils/llama_utils.py +5 -5
app.py CHANGED

@@ -35,6 +35,7 @@ st.markdown("""
 # ---------------------------------------
 base_path = "data/"
 base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+base_model_path_3B = "meta-llama/Llama-3.2-3B-Instruct"
 adapter_path = "./LLaMA-TOMMI-1.0/"

 st.title(":red[AI University] :gray[/] FEM")

@@ -115,12 +116,12 @@ with st.sidebar:
     # Choose the LLM model
     st.session_state.synthesis_model = st.selectbox(
         "Choose the LLM model",
-        ["LLaMA-3.2-
+        ["LLaMA-3.2-3B","gpt-4o-mini"], # "LLaMA-3.2-11B",
         index=1,
         key='a2model'
     )

-    if st.session_state.synthesis_model
+    if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
         synthesis_do_sample = st.toggle("Enable Sampling", value=False, key='synthesis_sample')

         if synthesis_do_sample:

@@ -169,6 +170,14 @@ with col2:
         help=question_help
     )

+with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."):
+    if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B":
+        if 'tommi_model' not in st.session_state:
+            tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
+            st.session_state.tommi_model = tommi_model
+            st.session_state.tommi_tokenizer = tommi_tokenizer
+
+
 with st.spinner("Loading LLaMA-3.2-11B..."):
     if "LLaMA-3.2-11B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
         if 'llama_model' not in st.session_state:

@@ -176,12 +185,12 @@ with st.spinner("Loading LLaMA-3.2-11B..."):
            st.session_state.llama_model = llama_model
            st.session_state.llama_tokenizer = llama_tokenizer

-with st.spinner("Loading LLaMA-
-    if st.session_state.expert_model
-        if '
-
-            st.session_state.
-            st.session_state.
+with st.spinner("Loading LLaMA-3.2-3B..."):
+    if "LLaMA-3.2-3B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
+        if 'llama_model_3B' not in st.session_state:
+            llama_model_3B, llama_tokenizer_3B = load_base_model(base_model_path_3B)
+            st.session_state.llama_model_3B = llama_model_3B
+            st.session_state.llama_tokenizer_3B = llama_tokenizer_3B

 # Load YouTube and LaTeX data
 text_data_YT, context_embeddings_YT = load_youtube_data(base_path, model_name, yt_chunk_tokens, yt_overlap_tokens)

@@ -264,6 +273,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
         model=model_,
         tokenizer=tokenizer_,
         messages=messages,
+        tokenizer_max_length=500,
         do_sample=expert_do_sample,
         temperature=expert_temperature if expert_do_sample else None,
         top_k=expert_top_k if expert_do_sample else None,

@@ -289,7 +299,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
     #-------------------------
     # synthesis responses
     #-------------------------
-    if st.session_state.synthesis_model
+    if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
         synthesis_prompt = f"""
         Question:
         {st.session_state.question}

@@ -311,6 +321,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
         model=st.session_state.llama_model,
         tokenizer=st.session_state.llama_tokenizer,
         messages=messages,
+        tokenizer_max_length=30000,
         do_sample=synthesis_do_sample,
         temperature=synthesis_temperature if synthesis_do_sample else None,
         top_k=synthesis_top_k if synthesis_do_sample else None,
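The new block above loads the 3B model with load_base_model(base_model_path_3B) and caches the result in st.session_state, so Streamlit's rerun-on-every-interaction model does not reload the weights each time. The loader's body is not part of this diff; what follows is only a rough sketch of its plausible shape, assuming a standard Hugging Face transformers from_pretrained flow (the dtype and device_map choices are guesses, not the repo's actual code).

# Sketch only -- the real load_base_model in utils/llama_utils.py may differ
# (precision, device placement, quantization are assumptions for an L4-sized GPU).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_base_model(model_path: str):
    """Return (model, tokenizer) for a base checkpoint such as Llama-3.2-3B-Instruct."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,  # assumption: half precision to fit GPU memory
        device_map="auto",           # assumption: let accelerate place the weights
    )
    return model, tokenizer

The returned pair matches how app.py unpacks it: llama_model_3B, llama_tokenizer_3B = load_base_model(base_model_path_3B).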
utils/llama_utils.py CHANGED

@@ -93,16 +93,16 @@ def generate_response(
     model: AutoModelForCausalLM,
     tokenizer: PreTrainedTokenizerFast,
     messages: list,
+    tokenizer_max_length: int = 500,
     do_sample: bool = False,
-    temperature: float = 0.
+    temperature: float = 0.1,
     top_k: int = 50,
     top_p: float = 0.95,
     num_beams: int = 1,
-    max_new_tokens: int =
+    max_new_tokens: int = 700
 ) -> str:
     """
     Runs inference on an LLM model.
-
     Args:
         model (AutoModelForCausalLM)
         tokenizer (PreTrainedTokenizerFast)

@@ -124,7 +124,7 @@ def generate_response(
     # Tokenize input
     inputs = tokenizer(
         input_text,
-        max_length=
+        max_length=tokenizer_max_length,
         truncation=True,
         return_tensors="pt"
     ).to(model.device)

@@ -158,4 +158,4 @@ def generate_response(

     response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)

-    return response
+    return response
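The new tokenizer_max_length parameter is forwarded to the tokenizer's max_length, capping how many prompt tokens survive truncation before generation. app.py now passes 500 on the expert call and 30000 on the synthesis call, presumably so the longer synthesis prompt (question plus retrieved context and expert answers) is not cut short. A hedged usage sketch of the updated signature, assuming a model and tokenizer already loaded via load_base_model as in app.py; the message text is made up for illustration.

# Illustrative call showing the new keyword; not the app's exact code path.
messages = [
    {"role": "system", "content": "You are a teaching assistant for a FEM course."},  # made-up content
    {"role": "user", "content": "Summarize the retrieved context into one answer."},  # made-up content
]

answer = generate_response(
    model=llama_model_3B,            # obtained earlier from load_base_model(base_model_path_3B)
    tokenizer=llama_tokenizer_3B,
    messages=messages,
    tokenizer_max_length=30000,      # new keyword from this commit; the value app.py uses for synthesis
    do_sample=False,                 # greedy decoding; app.py passes the sampling knobs as None in this case
)
print(answer)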