mostafa-sh committed
Commit b849b51 · Parent(s): 07c040d

add local model

Files changed (2):
  1. app.py +57 -13
  2. utils.py +141 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from sentence_transformers import SentenceTransformer
 from openai import OpenAI
 import random
 import prompts
+from utils import get_bnb_config, load_base_model, load_fine_tuned_model, generate_response
 
 st.set_page_config(page_title="AI University")
 
@@ -70,6 +71,10 @@ def fixed_knn_retrieval(question_embedding, context_embeddings, top_k=5, min_k=1
 def sec_to_time(start_time):
     return f"{start_time // 60:02}:{start_time % 60:02}"
 
+
+
+
+
 st.markdown("""
 <style>
 .video-wrapper {
@@ -161,22 +166,29 @@ with st.sidebar:
     # latex_overlap_tokens = latex_chunk_tokens // 4
     latex_overlap_tokens = 0
 
-    st.write(' ')
-    with st.expander('Expert model',expanded=False):
-    # st.write('**Expert model**')
-    # with st.container(border=True):
-        # Choose the LLM model
+    st.write(' ')
+    with st.expander('Expert model', expanded=False):
 
-        use_expert_answer = st.toggle("Use expert answer", value=True)
-        show_expert_responce = st.toggle("Show initial expert answer", value=False)
+        use_expert_answer = st.toggle("Use expert answer", value=True)
+        show_expert_responce = st.toggle("Show initial expert answer", value=False)
 
-        model = st.selectbox("Choose the LLM model", ["gpt-4o-mini", "gpt-3.5-turbo"], key='a1model')
+        model = st.selectbox("Choose the LLM model", ["gpt-4o-mini", "gpt-3.5-turbo", "llama-tommi-0.35"], key='a1model')
 
-        # Temperature
-        expert_temperature = st.slider("Temperature", 0.0, 0.3, .2, help="Defines the randomness in the next token prediction. Lower: More predictable and focused. Higher: More adventurous and diverse.", key='a1t')
+        if model == "llama-tommi-0.35":
+            tommi_do_sample = st.toggle("Enable Sampling", value=True, key='tommi_sample')
 
-        expert_top_p = st.slider("Top P", 0.1, 0.3, 0.1, help="Defines the range of token choices the model can consider in the next prediction. Lower: More focused and restricted to high-probability options. Higher: More creative, allowing consideration of less likely options.", key='a1p')
-
+            if tommi_do_sample:
+                tommi_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='tommi_temp')
+                tommi_top_k = st.slider("Top K", 0, 100, 50, key='tommi_top_k')
+                tommi_top_p = st.slider("Top P", 0.0, 1.0, 0.95, key='tommi_top_p')
+            else:
+                tommi_num_beams = st.slider("Num Beams", 1, 10, 4, key='tommi_num_beams')
+
+            tommi_max_new_tokens = st.slider("Max New Tokens", 100, 2000, 500, step=50, key='tommi_max_new_tokens')
+        else:
+            expert_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='a1t')
+            expert_top_p = st.slider("Top P", 0.0, 1.0, 0.9, key='a1p')
+            expert_top_k = st.slider("Top K", 0, 100, 50, key='a1k')
 
     with st.expander('Synthesis model',expanded=False):
 
@@ -281,9 +293,41 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
         context += context_item['text'] + '\n\n'
 
     if use_expert_answer:
-        st.session_state.expert_answer = prompts.openai_domain_specific_answer_generation("Finite Element Method", st.session_state.question, model=model, temperature=expert_temperature, top_p=expert_top_p)
+        if model == "llama-tommi-0.35":
+            if 'tommi_model' not in st.session_state:
+                tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
+                st.session_state.tommi_model = tommi_model
+                st.session_state.tommi_tokenizer = tommi_tokenizer
+
+            messages = [
+                {"role": "system", "content": "You are an expert in Finite Element Methods."},
+                {"role": "user", "content": st.session_state.question}
+            ]
+
+            st.session_state.expert_answer = generate_response(
+                model=st.session_state.tommi_model,
+                tokenizer=st.session_state.tommi_tokenizer,
+                messages=messages,
+                do_sample=tommi_do_sample,
+                temperature=tommi_temperature if tommi_do_sample else None,
+                top_k=tommi_top_k if tommi_do_sample else None,
+                top_p=tommi_top_p if tommi_do_sample else None,
+                num_beams=tommi_num_beams if not tommi_do_sample else 1,
+                max_new_tokens=tommi_max_new_tokens
+            )
+        else:
+            st.session_state.expert_answer = prompts.openai_domain_specific_answer_generation(
+                "Finite Element Method",
+                st.session_state.question,
+                model=model,
+                temperature=expert_temperature,
+                top_p=expert_top_p,
+                top_k=expert_top_k
+            )
     else:
         st.session_state.expert_answer = 'No Expert Answer. Only use the context.'
+
+
     answer = prompts.openai_context_integration("Finite Element Method", st.session_state.question, st.session_state.expert_answer, context, model=model, temperature=integration_temperature, top_p=integration_top_p)
 
     if answer.split()[0] == "NOT_ENOUGH_INFO":
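
Note: the new app.py branch calls load_fine_tuned_model(adapter_path, base_model_path), but neither variable is defined in the hunks shown above; they are presumably set elsewhere in app.py. A minimal sketch of the kind of definitions the code assumes (both values below are placeholders, not the repo's actual paths):

# Hypothetical configuration -- the real values live elsewhere in app.py.
base_model_path = "your-org/your-base-model"        # placeholder: base checkpoint on the HF Hub
adapter_path = "your-org/llama-tommi-0.35-adapter"  # placeholder: LoRA adapter for llama-tommi-0.35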
utils.py ADDED
@@ -0,0 +1,141 @@
+import torch
+from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
+from peft import PeftModel
+
+#-----------------------------------------
+# Quantization Config
+#-----------------------------------------
+def get_bnb_config():
+    return BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_storage=torch.float16
+    )
+
+#-----------------------------------------
+# Base Model Loader
+#-----------------------------------------
+def load_base_model(base_model_path: str):
+    """
+    Loads a base LLM model with 4-bit quantization and tokenizer.
+
+    Args:
+        base_model_path (str): HF model path
+
+    Returns:
+        model (AutoModelForCausalLM)
+        tokenizer (PreTrainedTokenizerFast)
+    """
+    bnb_config = get_bnb_config()
+
+    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model_path,
+        quantization_config=bnb_config,
+        trust_remote_code=True,
+        attn_implementation="eager",
+        torch_dtype=torch.float16
+    )
+
+    return model, tokenizer
+
+#-----------------------------------------
+# Fine-Tuned Model Loader
+#-----------------------------------------
+def load_fine_tuned_model(adapter_path: str, base_model_path: str):
+    """
+    Loads the fine-tuned model by applying a LoRA adapter to a base model.
+
+    Args:
+        adapter_path (str): Local or HF adapter path
+        base_model_path (str): Base LLM model path
+
+    Returns:
+        fine_tuned_model (PeftModel)
+        tokenizer (PreTrainedTokenizerFast)
+    """
+    bnb_config = get_bnb_config()
+
+    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_path,
+        quantization_config=bnb_config,
+        trust_remote_code=True,
+        attn_implementation="eager",
+        torch_dtype=torch.float16
+    )
+
+    fine_tuned_model = PeftModel.from_pretrained(
+        base_model,
+        adapter_path,
+        device_map="auto"
+    )
+
+    return fine_tuned_model, tokenizer
+
+#-----------------------------------------
+# Inference Function
+#-----------------------------------------
+@torch.no_grad()
+def generate_response(
+    model: AutoModelForCausalLM,
+    tokenizer: PreTrainedTokenizerFast,
+    messages: list,
+    do_sample: bool = False,
+    temperature: float = 0.7,
+    top_k: int = 50,
+    top_p: float = 0.95,
+    num_beams: int = 1,
+    max_new_tokens: int = 500
+) -> str:
+    """
+    Runs inference on an LLM model.
+
+    Args:
+        model (AutoModelForCausalLM)
+        tokenizer (PreTrainedTokenizerFast)
+        messages (list): List of dicts containing 'role' and 'content'
+
+    Returns:
+        str: Model response
+    """
+    # Ensure pad token exists
+    tokenizer.pad_token = "<|reserved_special_token_5|>"
+
+    # Create chat prompt
+    input_text = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=False
+    )
+
+    # Tokenize input
+    inputs = tokenizer(
+        input_text,
+        max_length=500,
+        truncation=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    generation_params = {
+        "do_sample": do_sample,
+        "temperature": temperature if do_sample else None,
+        "top_k": top_k if do_sample else None,
+        "top_p": top_p if do_sample else None,
+        "num_beams": num_beams if not do_sample else 1,
+        "max_new_tokens": max_new_tokens
+    }
+
+    output = model.generate(**inputs, **generation_params)
+
+    # Decode and clean up response
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    if 'assistant' in response:
+        response = response.split('assistant')[1].strip()
+
+    return response
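
For reference, a minimal standalone sketch of how the new utils.py API fits together, mirroring the calls app.py now makes (the model paths and the question below are placeholders, not values from this repo):

# Usage sketch for utils.py -- paths and the question are placeholders.
from utils import load_fine_tuned_model, generate_response

base_model_path = "your-org/your-base-model"   # hypothetical base checkpoint
adapter_path = "your-org/your-lora-adapter"    # hypothetical LoRA adapter

# Load the 4-bit quantized base model and attach the LoRA adapter.
model, tokenizer = load_fine_tuned_model(adapter_path, base_model_path)

messages = [
    {"role": "system", "content": "You are an expert in Finite Element Methods."},
    {"role": "user", "content": "What is a shape function?"},
]

# Beam search (do_sample=False); pass do_sample=True with temperature/top_k/top_p to sample instead.
answer = generate_response(model, tokenizer, messages,
                           do_sample=False, num_beams=4, max_new_tokens=300)
print(answer)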