JosephZ committed on
Commit
8cebbf6
·
verified ·
1 Parent(s): 1873cee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -29
app.py CHANGED
@@ -10,7 +10,6 @@ import torch
10
  from transformers import Qwen2VLForConditionalGeneration, GenerationConfig, AutoProcessor
11
  import spaces
12
 
13
- from vllm import LLM, SamplingParams
14
 
15
  def extract_answer_content(text: str) -> str:
16
  """
@@ -63,10 +62,6 @@ SYSTEM_PROMPT = (
63
  processor = AutoProcessor.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg", max_pixels=1024*28*28)
64
 
65
  device='cuda' if torch.cuda.is_available() else "cpu"
66
- model_name = "JosephZ/qwen2vl-7b-sft-grpo-close-sgg"
67
-
68
-
69
- """
70
  model = Qwen2VLForConditionalGeneration.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg",
71
  torch_dtype=torch.bfloat16,
72
  device_map=device)
@@ -80,25 +75,9 @@ generation_config=GenerationConfig(
80
  max_new_tokens=2048,
81
  use_cache=True
82
  )
83
- """
84
- model = LLM(
85
- model=model_name,
86
- limit_mm_per_prompt={"image": 1},
87
- dtype='bfloat16',
88
- #device=device,
89
- max_model_len=4096,
90
- mm_processor_kwargs= { "max_pixels": 1024*28*28, "min_pixels": 4*28*28},
91
- )
92
- sampling_params = SamplingParams(
93
- temperature=0.01,
94
- top_k=1,
95
- top_p=0.001,
96
- repetition_penalty=1.0,
97
- max_tokens=2048,
98
- )
99
 
100
  def build_prompt(image, user_text):
101
- base64_image = encode_image_to_base64(image)
102
  messages = [
103
  {
104
  "role": "system",
@@ -107,8 +86,8 @@ def build_prompt(image, user_text):
107
  {
108
  "role": "user",
109
  "content": [
110
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
111
- # {"type": "image"},
112
  {"type": "text", "text": user_text},
113
  ],
114
  },
@@ -176,17 +155,30 @@ def scale_box(box, scale):
176
  def generate_sgg(image):
177
  global model
178
 
 
 
 
179
 
180
  iw, ih = image.size
181
  scale_factors = (iw / 1000.0, ih / 1000.0)
182
-
183
  conversation = build_prompt(image, PROMPT_CLOSE)
 
184
 
 
 
 
 
185
  with torch.no_grad():
186
- outputs = model.chat([conversation], sampling_params=sampling_params)
187
- output_texts = [output.outputs[0].text for output in outputs]
 
 
 
 
 
 
188
 
189
- output_text = output_texts[0]
190
  resp = extract_answer_content(output_text)
191
 
192
  try:
@@ -226,4 +218,4 @@ gr.Interface(
226
  outputs=[gr.Image(type="pil"), gr.Textbox(label="Scene Graph")],
227
  title="R1-SGG: Compile Scene Graphs with Reinforcement Learning",
228
  description="Upload an image and generate a structured scene graph in JSON format."
229
- ).launch(share=True)
 
10
  from transformers import Qwen2VLForConditionalGeneration, GenerationConfig, AutoProcessor
11
  import spaces
12
 
 
13
 
14
  def extract_answer_content(text: str) -> str:
15
  """
 
62
  processor = AutoProcessor.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg", max_pixels=1024*28*28)
63
 
64
  device='cuda' if torch.cuda.is_available() else "cpu"
 
 
 
 
65
  model = Qwen2VLForConditionalGeneration.from_pretrained("JosephZ/qwen2vl-7b-sft-grpo-close-sgg",
66
  torch_dtype=torch.bfloat16,
67
  device_map=device)
 
75
  max_new_tokens=2048,
76
  use_cache=True
77
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def build_prompt(image, user_text):
80
+ #base64_image = encode_image_to_base64(image)
81
  messages = [
82
  {
83
  "role": "system",
 
86
  {
87
  "role": "user",
88
  "content": [
89
+ #{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
90
+ {"type": "image"},
91
  {"type": "text", "text": user_text},
92
  ],
93
  },
 
155
  def generate_sgg(image):
156
  global model
157
 
158
+ device='cuda' if torch.cuda.is_available() else "cpu"
159
+ if next(model.parameters()).device != torch.device(device):
160
+ model = model.to(device)
161
 
162
  iw, ih = image.size
163
  scale_factors = (iw / 1000.0, ih / 1000.0)
164
+
165
  conversation = build_prompt(image, PROMPT_CLOSE)
166
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
167
 
168
+ inputs = processor(
169
+ text=[text_prompt], images=[image], padding=True, return_tensors="pt"
170
+ )
171
+ inputs = inputs.to(model.device)
172
  with torch.no_grad():
173
+ output_ids = model.generate(**inputs, generation_config=generation_config)
174
+ generated_ids = [
175
+ output_ids[len(input_ids) :]
176
+ for input_ids, output_ids in zip(inputs.input_ids, output_ids)
177
+ ]
178
+ output_text = processor.batch_decode(
179
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
180
+ )[0]
181
 
 
182
  resp = extract_answer_content(output_text)
183
 
184
  try:
 
218
  outputs=[gr.Image(type="pil"), gr.Textbox(label="Scene Graph")],
219
  title="R1-SGG: Compile Scene Graphs with Reinforcement Learning",
220
  description="Upload an image and generate a structured scene graph in JSON format."
221
+ ).launch(share=True)