pivot-iterative-visual-optimization commited on
Commit
53ef1bb
·
verified ·
1 Parent(s): c7757d0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +12 -10
  2. vip_runner.py +22 -32
app.py CHANGED
@@ -15,9 +15,9 @@ def run_vip(
15
  n_samples_init,
16
  n_samples_opt,
17
  n_iters,
18
- n_recurssion,
19
  openai_api_key,
20
- progress=gr.Progress(track_tqdm=True),
21
  ):
22
 
23
  if not openai_api_key:
@@ -53,7 +53,7 @@ def run_vip(
53
  }
54
 
55
  vlm = GPT4V(openai_api_key=openai_api_key)
56
- ims, center, _ = vip_runner(
57
  vlm,
58
  im,
59
  query,
@@ -62,9 +62,10 @@ def run_vip(
62
  n_samples_init=n_samples_init,
63
  n_samples_opt=n_samples_opt,
64
  n_iters=n_iters,
65
- recursion_level=n_recurssion,
66
  )
67
- return ims, f'Final selected coordinate: {np.round(center, decimals=0)}'
 
68
 
69
 
70
  examples = [
@@ -117,11 +118,11 @@ The Info textbox will show the final selected pixel coordinate that PIVOT conver
117
  """.strip())
118
 
119
  gr.Markdown(
120
- '## Example Images and Queries\n Drag images into the image box below'
121
  )
122
  with gr.Row(equal_height=True):
123
  for example in examples:
124
- gr.Image(value=example['im_path'], label=example['desc'])
125
 
126
  gr.Markdown('## New Query')
127
  with gr.Row():
@@ -160,8 +161,8 @@ The Info textbox will show the final selected pixel coordinate that PIVOT conver
160
  inp_n_iters = gr.Slider(
161
  label='N Iterations', minimum=1, maximum=5, value=3, step=1
162
  )
163
- inp_n_recurssions = gr.Slider(
164
- label='N Ensemble Recursions', minimum=0, maximum=3, value=0, step=1
165
  )
166
  btn_run = gr.Button('Run')
167
 
@@ -171,6 +172,7 @@ The Info textbox will show the final selected pixel coordinate that PIVOT conver
171
  columns=4,
172
  rows=1,
173
  interactive=False,
 
174
  )
175
  out_info = gr.Textbox(label='Info', lines=1)
176
 
@@ -182,7 +184,7 @@ The Info textbox will show the final selected pixel coordinate that PIVOT conver
182
  inp_n_samples_init,
183
  inp_n_samples_opt,
184
  inp_n_iters,
185
- inp_n_recurssions,
186
  inp_openai_api_key,
187
  ],
188
  outputs=[out_ims, out_info],
 
15
  n_samples_init,
16
  n_samples_opt,
17
  n_iters,
18
+ n_parallel_trials,
19
  openai_api_key,
20
+ progress=gr.Progress(track_tqdm=False),
21
  ):
22
 
23
  if not openai_api_key:
 
53
  }
54
 
55
  vlm = GPT4V(openai_api_key=openai_api_key)
56
+ vip_gen = vip_runner(
57
  vlm,
58
  im,
59
  query,
 
62
  n_samples_init=n_samples_init,
63
  n_samples_opt=n_samples_opt,
64
  n_iters=n_iters,
65
+ n_parallel_trials=n_parallel_trials,
66
  )
67
+ for rst in vip_gen:
68
+ yield rst
69
 
70
 
71
  examples = [
 
118
  """.strip())
119
 
120
  gr.Markdown(
121
+ '## Example Images and Queries\n Drag images into the image box below (Try safari on Mac if dragging does not work)'
122
  )
123
  with gr.Row(equal_height=True):
124
  for example in examples:
125
+ gr.Image(value=example['im_path'], type='numpy', label=example['desc'])
126
 
127
  gr.Markdown('## New Query')
128
  with gr.Row():
 
161
  inp_n_iters = gr.Slider(
162
  label='N Iterations', minimum=1, maximum=5, value=3, step=1
163
  )
164
+ inp_n_parallel_trials = gr.Slider(
165
+ label='N Parallel Trials', minimum=1, maximum=3, value=1, step=1
166
  )
167
  btn_run = gr.Button('Run')
168
 
 
172
  columns=4,
173
  rows=1,
174
  interactive=False,
175
+ object_fit="contain", height="auto"
176
  )
177
  out_info = gr.Textbox(label='Info', lines=1)
178
 
 
184
  inp_n_samples_init,
185
  inp_n_samples_opt,
186
  inp_n_iters,
187
+ inp_n_parallel_trials,
188
  inp_openai_api_key,
189
  ],
190
  outputs=[out_ims, out_info],
vip_runner.py CHANGED
@@ -5,6 +5,7 @@ import re
5
 
6
  import cv2
7
  from tqdm import trange
 
8
  import vip
9
 
10
 
@@ -48,7 +49,11 @@ def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
48
  prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles]
49
  response = vlm.query(prompt_seq)
50
 
51
- arrow_ids = extract_json(response, "points")
 
 
 
 
52
  return arrow_ids, image_circles_np
53
 
54
 
@@ -61,7 +66,7 @@ def vip_runner(
61
  n_samples_init=25,
62
  n_samples_opt=10,
63
  n_iters=3,
64
- recursion_level=0,
65
  ):
66
  """VIP."""
67
 
@@ -72,10 +77,11 @@ def vip_runner(
72
  output_ims = []
73
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))
74
 
75
- if recursion_level == 0:
 
 
76
  center_mean = action_spec["loc"]
77
  center_std = action_spec["scale"]
78
- selected_samples = []
79
  for itr in trange(n_iters):
80
  if itr == 0:
81
  style["num_samples"] = n_samples_init
@@ -96,6 +102,7 @@ def vip_runner(
96
  image_circles_np, selected_samples, arm_coord
97
  )
98
  output_ims.append(image_circles_marked_np)
 
99
 
100
  # if at last iteration, pick one answer out of the selected ones
101
  if itr == n_iters - 1:
@@ -112,30 +119,11 @@ def vip_runner(
112
  im, selected_samples, arm_coord
113
  )
114
  output_ims.append(image_circles_marked_np)
 
 
115
  center_mean, center_std = prompter.fit(arrow_ids, samples)
116
 
117
- if output_ims:
118
- return (
119
- output_ims,
120
- prompter.action_to_coord(center_mean, im, arm_coord).xy,
121
- selected_samples,
122
- )
123
- else:
124
- new_samples = []
125
- for i in range(3):
126
- out_ims, _, cur_samples = vip_runner(
127
- vlm=vlm,
128
- im=im,
129
- desc=desc,
130
- style=style,
131
- action_spec=action_spec,
132
- n_samples_init=n_samples_init,
133
- n_samples_opt=n_samples_opt,
134
- n_iters=n_iters,
135
- recursion_level=recursion_level - 1,
136
- )
137
- output_ims += out_ims
138
- new_samples += cur_samples
139
  # adjust sample label to avoid duplications
140
  for sample_id in range(len(new_samples)):
141
  new_samples[sample_id].label = str(sample_id)
@@ -154,10 +142,12 @@ def vip_runner(
154
  output_ims.append(image_circles_marked_np)
155
  center_mean, _ = prompter.fit(arrow_ids, new_samples)
156
 
157
- if output_ims:
158
- return (
159
- output_ims,
160
- prompter.action_to_coord(center_mean, im, arm_coord).xy,
161
- selected_samples,
162
- )
 
 
163
  return [], "Unable to understand query"
 
5
 
6
  import cv2
7
  from tqdm import trange
8
+ import numpy as np
9
  import vip
10
 
11
 
 
49
  prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles]
50
  response = vlm.query(prompt_seq)
51
 
52
+ try:
53
+ arrow_ids = extract_json(response, "points")
54
+ except Exception as e:
55
+ print(e)
56
+ arrow_ids = []
57
  return arrow_ids, image_circles_np
58
 
59
 
 
66
  n_samples_init=25,
67
  n_samples_opt=10,
68
  n_iters=3,
69
+ n_parallel_trials=1,
70
  ):
71
  """VIP."""
72
 
 
77
  output_ims = []
78
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))
79
 
80
+ new_samples = []
81
+ center_mean = action_spec["loc"]
82
+ for i in range(n_parallel_trials):
83
  center_mean = action_spec["loc"]
84
  center_std = action_spec["scale"]
 
85
  for itr in trange(n_iters):
86
  if itr == 0:
87
  style["num_samples"] = n_samples_init
 
102
  image_circles_np, selected_samples, arm_coord
103
  )
104
  output_ims.append(image_circles_marked_np)
105
+ yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} iteration {itr+1}/{n_iters}. Still working..."
106
 
107
  # if at last iteration, pick one answer out of the selected ones
108
  if itr == n_iters - 1:
 
119
  im, selected_samples, arm_coord
120
  )
121
  output_ims.append(image_circles_marked_np)
122
+ new_samples += selected_samples
123
+ yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} last iteration. Still working..."
124
  center_mean, center_std = prompter.fit(arrow_ids, samples)
125
 
126
+ if n_parallel_trials > 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # adjust sample label to avoid duplications
128
  for sample_id in range(len(new_samples)):
129
  new_samples[sample_id].label = str(sample_id)
 
142
  output_ims.append(image_circles_marked_np)
143
  center_mean, _ = prompter.fit(arrow_ids, new_samples)
144
 
145
+ if output_ims:
146
+ yield (
147
+ output_ims,
148
+ (
149
+ "Final selected coordinate:"
150
+ f" {np.round(prompter.action_to_coord(center_mean, im, arm_coord).xy, decimals=0)}"
151
+ ),
152
+ )
153
  return [], "Unable to understand query"