qnguyen3 committed
Commit 0159119 · verified · 1 Parent(s): 50bd3d6

Update app.py

Files changed (1)
  1. app.py +50 -131
app.py CHANGED
@@ -6,17 +6,21 @@ from threading import Thread
  import re
  import time
  from PIL import Image
+ import torch
  import spaces
  import subprocess
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
  
- # Initialize tokenizer (doesn't require CUDA)
  tokenizer = AutoTokenizer.from_pretrained(
      'qnguyen3/nanoLLaVA-1.5',
      trust_remote_code=True)
  
- # Don't initialize model here - move it to the GPU-decorated function
- model = None
+ model = LlavaQwen2ForCausalLM.from_pretrained(
+     'qnguyen3/nanoLLaVA-1.5',
+     torch_dtype=torch.float16,
+     attn_implementation="flash_attention_2",
+     trust_remote_code=True,
+     device_map='auto')
  
  class KeywordsStoppingCriteria(StoppingCriteria):
      def __init__(self, keywords, tokenizer, input_ids):
@@ -55,154 +59,69 @@ class KeywordsStoppingCriteria(StoppingCriteria):
  
  @spaces.GPU
  def bot_streaming(message, history):
-     global model
- 
-     # Initialize the model inside the GPU-decorated function
-     if model is None:
-         model = LlavaQwen2ForCausalLM.from_pretrained(
-             'qnguyen3/nanoLLaVA-1.5',
-             torch_dtype=torch.float16,
-             attn_implementation="flash_attention_2",
-             trust_remote_code=True,
-             device_map="auto")  # Use "auto" instead of 'cpu' then manual to('cuda')
- 
-     # Get image path
-     image = None
-     if "files" in message and message["files"]:
-         image = message["files"][-1]["path"]
- 
-     # Check if image is available
-     if image is None:
-         return "Please upload an image for LLaVA to work."
- 
-     # Prepare conversation messages
      messages = []
-     if len(history) > 0:
-         for human, assistant in history:
-             # Skip None responses (which can happen during streaming)
-             if assistant is not None:
-                 messages.append({"role": "user", "content": human})
-                 messages.append({"role": "assistant", "content": assistant})
-         # Add the current message
-         messages.append({"role": "user", "content": f"<image>\n{message['text']}" if len(messages) == 0 else message['text']})
+     if message["files"]:
+         image = message["files"][-1]["path"]
      else:
+         for i, hist in enumerate(history):
+             if type(hist[0])==tuple:
+                 image = hist[0][0]
+                 image_turn = i
+ 
+     if len(history) > 0 and image is not None:
+         messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
+         messages.append({"role": "assistant", "content": history[1][1] })
+         for human, assistant in history[2:]:
+             messages.append({"role": "user", "content": human })
+             messages.append({"role": "assistant", "content": assistant })
+         messages.append({"role": "user", "content": message['text']})
+     elif len(history) > 0 and image is None:
+         for human, assistant in history:
+             messages.append({"role": "user", "content": human })
+             messages.append({"role": "assistant", "content": assistant })
+         messages.append({"role": "user", "content": message['text']})
+     elif len(history) == 0 and image is not None:
          messages.append({"role": "user", "content": f"<image>\n{message['text']}"})
+     elif len(history) == 0 and image is None:
+         messages.append({"role": "user", "content": message['text'] })
+     model = model.to('cuda')
  
-     # Process image
+     # if image is None:
+     #     gr.Error("You need to upload an image for LLaVA to work.")
      image = Image.open(image).convert("RGB")
- 
-     # Prepare input for generation
      text = tokenizer.apply_chat_template(
          messages,
          tokenize=False,
          add_generation_prompt=True)
- 
-     # Handle image embedding in text
-     if '<image>' in text:
-         text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
-         input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
-     else:
-         # If no <image> tag was added (possible in some chat templates), add it manually
-         input_ids = tokenizer(text).input_ids
-         # Find the position to insert the image token
-         # For simplicity, insert after the user message start
-         user_start_pos = 0
-         for i, token in enumerate(input_ids):
-             if tokenizer.decode([token]) == '<|im_start|>user':
-                 user_start_pos = i + 2  # +2 to get past the tag
-                 break
-         # Insert image token
-         input_ids = input_ids[:user_start_pos] + [-200] + input_ids[user_start_pos:]
-         input_ids = torch.tensor([input_ids], dtype=torch.long)
- 
-     # Prepare stopping criteria
+     text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
+     input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
      stop_str = '<|im_end|>'
      keywords = [stop_str]
      stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
  
-     # Process image and generate text
      image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
-     generation_kwargs = dict(
-         input_ids=input_ids,
-         images=image_tensor,
-         streamer=streamer,
-         max_new_tokens=512,
-         stopping_criteria=[stopping_criteria],
-         temperature=0.01
-     )
- 
+     generation_kwargs = dict(input_ids=input_ids.to('cuda'),
+                              images=image_tensor.to('cuda'),
+                              streamer=streamer, max_new_tokens=512,
+                              stopping_criteria=[stopping_criteria], temperature=0.01)
+     generated_text = ""
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
+     text_prompt = f"<|im_start|>user\n{message['text']}<|im_end|>"
  
-     # Stream response
      buffer = ""
      for new_text in streamer:
-         buffer += new_text
-         generated_text_without_prompt = buffer[:]
-         time.sleep(0.04)
-         yield generated_text_without_prompt
+ 
+         buffer += new_text
+ 
+         generated_text_without_prompt = buffer[:]
+         time.sleep(0.04)
+         yield generated_text_without_prompt
  
  
- # Create a gradio Blocks interface instead of ChatInterface
- # This avoids the schema validation issues
- with gr.Blocks(title="🚀nanoLLaVA-1.5") as demo:
-     gr.Markdown("## 🚀nanoLLaVA-1.5")
-     gr.Markdown("Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.")
- 
-     chatbot = gr.Chatbot(height=500)
-     with gr.Row():
-         with gr.Column(scale=0.8):
-             msg = gr.Textbox(
-                 show_label=False,
-                 placeholder="Enter text and upload an image",
-                 container=False
-             )
-         with gr.Column(scale=0.2):
-             btn = gr.Button("Submit")
-             stop_btn = gr.Button("Stop Generation")
- 
-     upload_btn = gr.UploadButton("Upload Image", file_types=["image"])
-     current_img = gr.State(None)
- 
-     # Example images
-     examples = gr.Examples(
-         examples=[
-             ["Who is this guy?", "./demo_1.jpg"],
-             ["What does the text say?", "./demo_2.jpeg"]
-         ],
-         inputs=[msg, upload_btn]
-     )
- 
-     def upload_image(image):
-         return image
- 
-     def add_text(history, text, image):
-         if image is None and (not history or type(history[0][0]) != tuple):
-             return history + [[text, "Please upload an image first."]]
-         return history + [[text, None]]
- 
-     def bot_response(history, image):
-         message = {"text": history[-1][0], "files": [{"path": image}] if image else []}
-         history_format = history[:-1]  # All except the last message
- 
-         response = ""
-         for chunk in bot_streaming(message, history_format):
-             response = chunk
-             history[-1][1] = response
-             yield history
- 
-     upload_btn.upload(upload_image, upload_btn, current_img)
- 
-     msg.submit(add_text, [chatbot, msg, current_img], chatbot).then(
-         bot_response, [chatbot, current_img], chatbot
-     )
- 
-     btn.click(add_text, [chatbot, msg, current_img], chatbot).then(
-         bot_response, [chatbot, current_img], chatbot
-     )
- 
-     stop_btn.click(None, None, None, cancels=[bot_response])
- 
-     # Launch the app with queuing
+ demo = gr.ChatInterface(fn=bot_streaming, title="🚀nanoLLaVA-1.5", examples=[{"text": "Who is this guy?", "files":["./demo_1.jpg"]},
+                         {"text": "What does the text say?", "files":["./demo_2.jpeg"]}],
+                         description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+                         stop_btn="Stop Generation", multimodal=True)
  demo.queue().launch()
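
Note: the body of KeywordsStoppingCriteria falls between the two hunks above, so the diff never shows it. For context, a minimal sketch of how such a keyword-based stopping criterion is typically written (an assumed implementation; the actual class in app.py may differ in detail):

    from transformers import StoppingCriteria

    class KeywordsStoppingCriteria(StoppingCriteria):
        def __init__(self, keywords, tokenizer, input_ids):
            self.keywords = keywords
            self.tokenizer = tokenizer
            # Remember the prompt length so only newly generated tokens are checked.
            self.start_len = input_ids.shape[1]

        def __call__(self, output_ids, scores, **kwargs):
            # Decode the tokens generated so far; keep special tokens so that
            # keywords like '<|im_end|>' can actually be matched.
            generated = self.tokenizer.batch_decode(
                output_ids[:, self.start_len:], skip_special_tokens=False)[0]
            return any(keyword in generated for keyword in self.keywords)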
 
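With multimodal=True, gr.ChatInterface passes bot_streaming a dict-shaped message rather than a plain string. A small usage sketch, assuming the message format the function reads ({"text": ..., "files": [{"path": ...}]}) and reusing one of the bundled example images:

    # Drive the streaming function directly (outside the Gradio UI) to inspect its output.
    message = {
        "text": "What does the text say?",
        "files": [{"path": "./demo_2.jpeg"}],  # bot_streaming uses the last file as the image
    }
    history = []  # prior (user, assistant) turns; empty on the first call

    for partial in bot_streaming(message, history):
        print(partial)  # each yield is the progressively longer streamed reply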