sagar007 committed on
Commit f87dcd8 · verified · 1 Parent(s): 1abfce8

Update app.py

Files changed (1)
  1. app.py +279 -44
app.py CHANGED
@@ -4,63 +4,268 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
 from PIL import Image
 import logging
 import spaces
-import numpy as np
+import numpy

+# Setup logging
 logging.basicConfig(level=logging.INFO)

 class LLaVAPhiModel:
-    def __init__(self, model_id="sagar007/Lava_phi"):
+    def __init__(self, model_id="microsoft/phi-1_5"):  # Updated to match config
         self.device = "cuda"
         self.model_id = model_id
-        logging.info("Initializing LLaVA-Phi model...")
+        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")

+        # Initialize tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+        try:
+            # Use CLIPProcessor with the correct model name from config
+            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+            logging.info("Successfully loaded CLIP processor")
+        except Exception as e:
+            logging.error(f"Failed to load CLIP processor: {str(e)}")
+            self.processor = None
+
         self.history = []
         self.model = None
         self.clip = None
-        self.projection = None
+
+        # Default generation parameters - can be updated from config
+        self.temperature = 0.3
+        self.top_p = 0.92
+        self.top_k = 50
+        self.repetition_penalty = 1.2
+
+        # Set max length from config
+        self.max_length = 512  # Default value, will be updated from config

     @spaces.GPU
     def ensure_models_loaded(self):
-        if not torch.cuda.is_available():
-            raise RuntimeError("CUDA is not available. This model requires a GPU.")
+        """Ensure models are loaded in GPU context"""
         if self.model is None:
+            # Use 4-bit quantization according to config
             from transformers import BitsAndBytesConfig
             quantization_config = BitsAndBytesConfig(
-                load_in_8bit=True,
-                bnb_8bit_compute_dtype=torch.float16,
-                bnb_8bit_use_double_quant=False
+                load_in_4bit=True,  # Changed to match config
+                bnb_4bit_compute_dtype=torch.bfloat16,  # Changed to bfloat16 to match config's mixed_precision
+                bnb_4bit_use_double_quant=False
             )
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_id,
-                quantization_config=quantization_config,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                trust_remote_code=True
-            )
-            self.model.config.pad_token_id = self.tokenizer.eos_token_id
-            logging.info("Successfully loaded main model on GPU")
+
+            try:
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_id,
+                    quantization_config=quantization_config,
+                    device_map="auto",
+                    torch_dtype=torch.bfloat16,
+                    trust_remote_code=True
+                )
+                self.model.config.pad_token_id = self.tokenizer.eos_token_id
+                logging.info(f"Successfully loaded main model: {self.model_id}")
+            except Exception as e:
+                logging.error(f"Failed to load main model: {str(e)}")
+                raise

         if self.clip is None:
-            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
-            logging.info("Successfully loaded CLIP model")
-            embed_dim = self.model.config.hidden_size
-            clip_dim = self.clip.config.projection_dim
-            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
+            try:
+                # Load CLIP model from config
+                clip_model_name = "openai/clip-vit-base-patch32"  # From config
+                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
+                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
+            except Exception as e:
+                logging.error(f"Failed to load CLIP model: {str(e)}")
+                self.clip = None

-    # Rest of your class (process_image, generate_response, etc.) remains unchanged
-    # ... (omitted for brevity)
+    def apply_lora_config(self, lora_params):
+        """Apply LoRA configuration to the model - to be called during training"""
+        from peft import LoraConfig, get_peft_model
+
+        lora_config = LoraConfig(
+            r=lora_params.get("r", 16),
+            lora_alpha=lora_params.get("lora_alpha", 32),
+            lora_dropout=lora_params.get("lora_dropout", 0.05),
+            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
+            bias="none",
+            task_type="CAUSAL_LM"
+        )
+
+        # Convert model to PEFT/LoRA model
+        self.model = get_peft_model(self.model, lora_config)
+        logging.info("Applied LoRA configuration to the model")
+        return self.model
+
+    @spaces.GPU(duration=120)
+    def generate_response(self, message, image=None):
+        try:
+            self.ensure_models_loaded()
+
+            # Prepare prompt based on whether we have an image
+            has_image = image is not None
+
+            # Process text input
+            if has_image:
+                # For image+text input
+                prompt = f"human: <image>\n{message}\ngpt:"
+
+                # Check if model has vision encoding capability
+                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
+                    logging.warning("Model doesn't have standard image encoding methods")
+                    has_image = False
+                    prompt = f"human: {message}\ngpt:"
+            else:
+                # For text-only input
+                prompt = f"human: {message}\ngpt:"
+
+            # Include previous conversation context
+            context = ""
+            for turn in self.history[-5:]:  # Include 5 previous turns
+                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
+
+            full_prompt = context + prompt
+
+            # Tokenize the input text
+            inputs = self.tokenizer(
+                full_prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=self.max_length
+            )
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # LLaVA-Phi specific image handling
+            if has_image:
+                try:
+                    # Convert image to correct format
+                    if isinstance(image, str):
+                        image = Image.open(image)
+                    elif isinstance(image, numpy.ndarray):
+                        image = Image.fromarray(image)
+
+                    # Ensure image is in RGB mode
+                    if image.mode != 'RGB':
+                        image = image.convert('RGB')
+
+                    # Process the image with CLIP processor
+                    image_inputs = self.processor(images=image, return_tensors="pt")
+                    image_features = self.clip.get_image_features(
+                        pixel_values=image_inputs.pixel_values.to(self.device)
+                    )
+
+                    # Some LLaVA models have a prepare_inputs_for_generation method
+                    if hasattr(self.model, "prepare_inputs_for_generation"):
+                        logging.info("Using model's prepare_inputs_for_generation for image handling")
+
+                    # Generate with image context
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            **inputs,
+                            max_new_tokens=256,
+                            min_length=20,
+                            temperature=self.temperature,
+                            do_sample=True,
+                            top_p=self.top_p,
+                            top_k=self.top_k,
+                            repetition_penalty=self.repetition_penalty,
+                            no_repeat_ngram_size=3,
+                            use_cache=True,
+                            pad_token_id=self.tokenizer.pad_token_id,
+                            eos_token_id=self.tokenizer.eos_token_id
+                        )
+
+                except Exception as e:
+                    logging.error(f"Error handling image: {str(e)}")
+                    # Fall back to text-only generation
+                    logging.info("Falling back to text-only generation")
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            **inputs,
+                            max_new_tokens=256,
+                            min_length=20,
+                            temperature=self.temperature,
+                            do_sample=True,
+                            top_p=self.top_p,
+                            top_k=self.top_k,
+                            repetition_penalty=self.repetition_penalty,
+                            no_repeat_ngram_size=3,
+                            use_cache=True,
+                            pad_token_id=self.tokenizer.pad_token_id,
+                            eos_token_id=self.tokenizer.eos_token_id
+                        )
+            else:
+                # Text-only generation
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        **inputs,
+                        max_new_tokens=200,
+                        min_length=20,
+                        temperature=self.temperature,
+                        do_sample=True,
+                        top_p=self.top_p,
+                        top_k=self.top_k,
+                        repetition_penalty=self.repetition_penalty,
+                        no_repeat_ngram_size=4,
+                        use_cache=True,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id
+                    )
+
+            # Decode and clean up the response
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Clean up response
+            if "gpt:" in response:
+                response = response.split("gpt:")[-1].strip()
+            if "human:" in response:
+                response = response.split("human:")[0].strip()
+            if "<image>" in response:
+                response = response.replace("<image>", "").strip()
+
+            self.history.append((message, response))
+            return response
+
+        except Exception as e:
+            logging.error(f"Error generating response: {str(e)}")
+            logging.error("Full traceback:", exc_info=True)
+            return f"Error: {str(e)}"
+
+    def clear_history(self):
+        self.history = []
+        return None
+
+    # Add new function to control generation parameters
+    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
+        """Update generation parameters to control hallucination tendency"""
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.repetition_penalty = repetition_penalty
+        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"
+
+    # New method to apply config file settings
+    def apply_config(self, config):
+        """Apply settings from config file"""
+        model_params = config.get("model_params", {})
+        self.model_id = model_params.get("model_name", self.model_id)
+        self.max_length = model_params.get("max_length", 512)
+
+        # Update generation parameters if needed
+        training_params = config.get("training_params", {})
+        # Could add specific updates based on training_params if needed
+
+        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"

-def create_demo():
+def create_demo(config=None):
     try:
+        # Initialize with config file settings
         model = LLaVAPhiModel()

-        demo = gr.Blocks(css="footer {visibility: hidden}")
-        with demo:
+        if config:
+            model.apply_config(config)
+
+        with gr.Blocks(css="footer {visibility: hidden}") as demo:
             gr.Markdown(
                 """
                 # LLaVA-Phi Demo (Optimized for Accuracy)
@@ -83,58 +288,88 @@ def create_demo():

             image = gr.Image(type="pil", label="Upload Image (Optional)")

-            with gr.Accordion("Advanced Settings", open=False):
+            # Add generation parameter controls
+            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                 gr.Markdown("Adjust these parameters to control hallucination tendency")
                 temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                 top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                 top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                 rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                 update_params = gr.Button("Update Parameters")
+
+            # Add debugging information box
+            debug_info = gr.Textbox(label="Debug Info", interactive=False)
+
+            # Add config information
+            if config:
+                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
+                gr.Markdown(f"**Current Configuration:** {config_info}")

             def respond(message, chat_history, image):
                 if not message and image is None:
-                    return chat_history
+                    return message, chat_history, ""

-                response = model.generate_response(message, image)
-                chat_history.append((message, response))
-                return "", chat_history
+                try:
+                    response = model.generate_response(message, image)
+                    chat_history.append((message, response))
+                    debug_msg = "Response generated successfully"
+                    return "", chat_history, debug_msg
+                except Exception as e:
+                    debug_msg = f"Error: {str(e)}"
+                    return message, chat_history, debug_msg

             def clear_chat():
                 model.clear_history()
-                return None, None
+                return None, None, "Chat history cleared"

             def update_params_fn(temp, top_p, top_k, rep_penalty):
-                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
+                result = model.update_generation_params(temp, top_p, top_k, rep_penalty)
+                return f"Parameters updated: temp={temp}, top_p={top_p}, top_k={top_k}, rep_penalty={rep_penalty}"

             submit.click(
                 respond,
                 [msg, chatbot, image],
-                [msg, chatbot],
+                [msg, chatbot, debug_info],
             )

             clear.click(
                 clear_chat,
                 None,
-                [chatbot, image],
+                [chatbot, image, debug_info],
             )

             msg.submit(
                 respond,
                 [msg, chatbot, image],
-                [msg, chatbot],
+                [msg, chatbot, debug_info],
             )

             update_params.click(
                 update_params_fn,
                 [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
-                None
+                [debug_info]
             )

         return demo
     except Exception as e:
         logging.error(f"Error creating demo: {str(e)}")
         raise

 if __name__ == "__main__":
-    demo = create_demo()
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+    # Load config file
+    import json
+
+    try:
+        with open("config.json", "r") as f:
+            config = json.load(f)
+        logging.info("Successfully loaded config file")
+    except Exception as e:
+        logging.error(f"Error loading config: {str(e)}")
+        config = None
+
+    demo = create_demo(config)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True
+    )
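
For reference, the new `__main__` block expects a `config.json` next to `app.py`, and `apply_config` only reads `model_params.model_name`, `model_params.max_length`, and an as-yet-unused `training_params` block. Below is a minimal sketch of a compatible file; the key names are inferred from the code in this commit, not copied from the repo's actual config:

```python
# Minimal sketch of a config accepted by apply_config()/create_demo(config).
# Key names are inferred from the commit above; the repo's real config.json
# may carry additional fields.
import json

config = {
    "model_params": {
        "model_name": "microsoft/phi-1_5",  # copied into model.model_id
        "max_length": 512,                  # copied into model.max_length
    },
    "training_params": {},  # read by apply_config() but not used yet
}

# Write it where the __main__ block looks for it...
with open("config.json", "w") as f:
    json.dump(config, f, indent=2)

# ...or skip the file entirely: apply_config() only calls config.get(),
# so the dict can be passed straight to create_demo(config).
```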
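`apply_lora_config` likewise takes a plain dict. The sketch below simply mirrors the defaults hard-coded in the method; note that the `target_modules` names (`Wqkv`, `out_proj`) are Phi-style attention projections that would need adjusting for other architectures, and that `peft` must be installed:

```python
# Hypothetical lora_params for apply_lora_config(); the values mirror
# the method's own defaults.
lora_params = {
    "r": 16,                                 # LoRA rank
    "lora_alpha": 32,                        # scaling factor
    "lora_dropout": 0.05,
    "target_modules": ["Wqkv", "out_proj"],  # Phi-style attention projections
}

model = LLaVAPhiModel()        # assumes CUDA; the class hard-codes self.device = "cuda"
model.ensure_models_loaded()   # base weights must be loaded before wrapping
model.apply_lora_config(lora_params)  # converts self.model to a PEFT/LoRA model
```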
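One behavior worth noting in `generate_response`: conversation history is flattened into the prompt as alternating `human:`/`gpt:` turns (only the last five are kept), and the reply is recovered afterwards by splitting on those markers. A self-contained illustration of the prompt layout, with invented conversation strings:

```python
# Illustration of the prompt built inside generate_response();
# the conversation content here is made up for the example.
history = [("What is CLIP?", "CLIP is a vision-language model from OpenAI.")]
message = "Who trained it?"

context = "".join(f"human: {h}\ngpt: {g}\n" for h, g in history[-5:])
full_prompt = context + f"human: {message}\ngpt:"
print(full_prompt)
# human: What is CLIP?
# gpt: CLIP is a vision-language model from OpenAI.
# human: Who trained it?
# gpt:
```

This is also why the cleanup step splits on `"gpt:"` and truncates at `"human:"`: the model is likely to continue the transcript pattern past its own turn.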