mostafa-sh committed
Commit fd97c8c · 1 Parent(s): 2b27faa

add 3B model

Files changed (2):
  1. app.py (+20 -9)
  2. utils/llama_utils.py (+5 -5)
app.py CHANGED
@@ -35,6 +35,7 @@ st.markdown("""
  # ---------------------------------------
  base_path = "data/"
  base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ base_model_path_3B = "meta-llama/Llama-3.2-3B-Instruct"
  adapter_path = "./LLaMA-TOMMI-1.0/"

  st.title(":red[AI University] :gray[/] FEM")
@@ -115,12 +116,12 @@ with st.sidebar:
  # Choose the LLM model
  st.session_state.synthesis_model = st.selectbox(
  "Choose the LLM model",
- ["LLaMA-3.2-11B", "gpt-4o-mini"],
+ ["LLaMA-3.2-3B","gpt-4o-mini"], # "LLaMA-3.2-11B",
  index=1,
  key='a2model'
  )

- if st.session_state.synthesis_model == "LLaMA-3.2-11B":
+ if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
  synthesis_do_sample = st.toggle("Enable Sampling", value=False, key='synthesis_sample')

  if synthesis_do_sample:
@@ -169,6 +170,14 @@ with col2:
  help=question_help
  )

+ with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."):
+ if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B":
+ if 'tommi_model' not in st.session_state:
+ tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
+ st.session_state.tommi_model = tommi_model
+ st.session_state.tommi_tokenizer = tommi_tokenizer
+
+
  with st.spinner("Loading LLaMA-3.2-11B..."):
  if "LLaMA-3.2-11B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
  if 'llama_model' not in st.session_state:
@@ -176,12 +185,12 @@
  st.session_state.llama_model = llama_model
  st.session_state.llama_tokenizer = llama_tokenizer

- with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."):
- if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B":
- if 'tommi_model' not in st.session_state:
- tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
- st.session_state.tommi_model = tommi_model
- st.session_state.tommi_tokenizer = tommi_tokenizer
+ with st.spinner("Loading LLaMA-3.2-3B..."):
+ if "LLaMA-3.2-3B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
+ if 'llama_model_3B' not in st.session_state:
+ llama_model_3B, llama_tokenizer_3B = load_base_model(base_model_path_3B)
+ st.session_state.llama_model_3B = llama_model_3B
+ st.session_state.llama_tokenizer_3B = llama_tokenizer_3B

  # Load YouTube and LaTeX data
  text_data_YT, context_embeddings_YT = load_youtube_data(base_path, model_name, yt_chunk_tokens, yt_overlap_tokens)
@@ -264,6 +273,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
  model=model_,
  tokenizer=tokenizer_,
  messages=messages,
+ tokenizer_max_length=500,
  do_sample=expert_do_sample,
  temperature=expert_temperature if expert_do_sample else None,
  top_k=expert_top_k if expert_do_sample else None,
@@ -289,7 +299,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
  #-------------------------
  # synthesis responses
  #-------------------------
- if st.session_state.synthesis_model == "LLaMA-3.2-11B":
+ if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
  synthesis_prompt = f"""
  Question:
  {st.session_state.question}
@@ -311,6 +321,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
  model=st.session_state.llama_model,
  tokenizer=st.session_state.llama_tokenizer,
  messages=messages,
+ tokenizer_max_length=30000,
  do_sample=synthesis_do_sample,
  temperature=synthesis_temperature if synthesis_do_sample else None,
  top_k=synthesis_top_k if synthesis_do_sample else None,
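
Note: the commit stores the 3B model and tokenizer under their own session-state keys (llama_model_3B, llama_tokenizer_3B) while the 11B pair keeps its existing keys. As a minimal sketch of how a caller could select the matching pair for the chosen synthesis model, the pick_llama helper below is hypothetical (not part of this commit) and assumes only the session-state keys visible in this diff.

# Hypothetical helper, not part of this commit: map the sidebar selection to the
# model/tokenizer pair that app.py keeps in st.session_state.
def pick_llama(session_state, selected_model: str):
    if selected_model == "LLaMA-3.2-3B":
        return session_state.llama_model_3B, session_state.llama_tokenizer_3B
    # fall back to the 11B pair (llama_model / llama_tokenizer)
    return session_state.llama_model, session_state.llama_tokenizer

# Usage sketch inside the synthesis branch:
# model_, tokenizer_ = pick_llama(st.session_state, st.session_state.synthesis_model)
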
utils/llama_utils.py CHANGED
@@ -93,16 +93,16 @@ def generate_response(
  model: AutoModelForCausalLM,
  tokenizer: PreTrainedTokenizerFast,
  messages: list,
+ tokenizer_max_length: int = 500,
  do_sample: bool = False,
- temperature: float = 0.7,
+ temperature: float = 0.1,
  top_k: int = 50,
  top_p: float = 0.95,
  num_beams: int = 1,
- max_new_tokens: int = 500
+ max_new_tokens: int = 700
  ) -> str:
  """
  Runs inference on an LLM model.
-
  Args:
  model (AutoModelForCausalLM)
  tokenizer (PreTrainedTokenizerFast)
@@ -124,7 +124,7 @@ def generate_response(
  # Tokenize input
  inputs = tokenizer(
  input_text,
- max_length=500,
+ max_length=tokenizer_max_length,
  truncation=True,
  return_tensors="pt"
  ).to(model.device)
@@ -158,4 +158,4 @@ def generate_response(

  response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)

- return response
+ return response
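
For reference, a short usage sketch of the updated generate_response signature. It mirrors how app.py now passes tokenizer_max_length (500 for the expert call, 30000 for the synthesis call); the import path, the assumption that load_base_model also lives in utils/llama_utils.py, and the example messages are illustrative, and the new defaults (temperature=0.1, max_new_tokens=700) apply only when the caller omits those arguments.

# Illustrative sketch only; load_base_model is called in app.py but its location
# and the import path below are assumptions, not shown in this diff.
from utils.llama_utils import load_base_model, generate_response

model, tokenizer = load_base_model("meta-llama/Llama-3.2-3B-Instruct")

messages = [
    {"role": "system", "content": "You are a teaching assistant for a FEM course."},
    {"role": "user", "content": "Summarize the weak form of the Poisson equation."},
]

answer = generate_response(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    tokenizer_max_length=500,   # new parameter: input truncation limit, now set per call
    do_sample=False,            # greedy decoding; app.py passes temperature/top_k/top_p only when sampling is enabled
)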