John Yang commited on
Commit
4b9c9b6
·
1 Parent(s): 1e7de71

Restore working version

Browse files
Files changed (2) hide show
  1. app.py +237 -1
  2. predict.py +0 -250
app.py CHANGED
@@ -1,13 +1,249 @@
1
  import gradio as gr
2
- from predict import run_episode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  gr.Interface(fn=run_episode,\
5
  inputs=gr.inputs.Textbox(lines=7, label="Input Text"),\
6
  outputs="text",\
7
  examples=[
 
8
  "I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars",
9
  "I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound",
10
  "I'm looking for a kids toothbrush for ages 6 to 12 that will help with teeth whitening and is easy to use",
 
11
  ],\
12
  title="WebShop",\
13
  article="<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>",\
 
1
  import gradio as gr
2
+ import time, torch
3
+ from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
4
+
5
+ from webshop_lite import dict_to_fake_html
6
+ from predict_help import convert_dict_to_actions, convert_html_to_text, parse_results, parse_item_page, Page
7
+
8
+ # load IL models
9
+ bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
10
+ bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')
11
+
12
+ bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
13
+ bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
14
+ bert_model = AutoModel.from_pretrained('webshop/il-choice-bert-image_0', trust_remote_code=True)
15
+
16
+ def process_str(s):
17
+ s = s.lower().replace('"', '').replace("'", "").strip()
18
+ s = s.replace('[sep]', '[SEP]')
19
+ return s
20
+
21
+
22
+ def process_goal(state):
23
+ state = state.lower().replace('"', '').replace("'", "")
24
+ state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
25
+ if ', and price lower than' in state:
26
+ state = state.split(', and price lower than')[0]
27
+ return state
28
+
29
+
30
+ def data_collator(batch):
31
+ state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
32
+ for sample in batch:
33
+ state_input_ids.append(sample['state_input_ids'])
34
+ state_attention_mask.append(sample['state_attention_mask'])
35
+ action_input_ids.extend(sample['action_input_ids'])
36
+ action_attention_mask.extend(sample['action_attention_mask'])
37
+ sizes.append(sample['sizes'])
38
+ labels.append(sample['labels'])
39
+ images.append(sample['images'])
40
+ max_state_len = max(sum(x) for x in state_attention_mask)
41
+ max_action_len = max(sum(x) for x in action_attention_mask)
42
+ return {
43
+ 'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
44
+ 'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
45
+ 'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
46
+ 'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
47
+ 'sizes': torch.tensor(sizes),
48
+ 'images': torch.tensor(images),
49
+ 'labels': torch.tensor(labels),
50
+ }
51
+
52
+
53
+ def bart_predict(input):
54
+ input_ids = bart_tokenizer(input)['input_ids']
55
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
56
+ output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
57
+ return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]
58
+
59
+
60
+ def bert_predict(obs, info, softmax=True):
61
+ valid_acts = info['valid']
62
+ assert valid_acts[0].startswith('click[')
63
+ state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
64
+ action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
65
+ batch = {
66
+ 'state_input_ids': state_encodings['input_ids'],
67
+ 'state_attention_mask': state_encodings['attention_mask'],
68
+ 'action_input_ids': action_encodings['input_ids'],
69
+ 'action_attention_mask': action_encodings['attention_mask'],
70
+ 'sizes': len(valid_acts),
71
+ 'images': info['image_feat'].tolist(),
72
+ 'labels': 0
73
+ }
74
+ batch = data_collator([batch])
75
+ outputs = bert_model(**batch)
76
+ if softmax:
77
+ idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
78
+ else:
79
+ idx = outputs.logits[0].argmax(0).item()
80
+ return valid_acts[idx]
81
+
82
+
83
+ def predict(obs, info):
84
+ """
85
+ Given WebShop environment observation and info, predict an action.
86
+ """
87
+ valid_acts = info['valid']
88
+ if valid_acts[0].startswith('click['):
89
+ return bert_predict(obs, info)
90
+ else:
91
+ return "search[" + bart_predict(process_goal(obs)) + "]"
92
+
93
+ def run_episode(goal, verbose=True):
94
+ """
95
+ Interact with amazon to find a product given input goal.
96
+ Input: text goal
97
+ Output: a url of found item on amazon.
98
+ """
99
+ obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
100
+ info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
101
+ product_map = {}
102
+ title_to_asin_map = {}
103
+ search_results_cache = {}
104
+ visited_asins, clicked_options = set(), set()
105
+ sub_page_type, page_type, page_num = None, None, None
106
+ search_terms, prod_title, asin, num_prods, = None, None, None, None
107
+ options = {}
108
+
109
+ for i in range(100):
110
+ # Run prediction
111
+ action = predict(obs, info)
112
+ if verbose:
113
+ print("====")
114
+ print(action)
115
+
116
+ # Previous Page Type, Action -> Next Page Type
117
+ action_content = action[action.find("[")+1:action.find("]")]
118
+ prev_page_type = page_type
119
+ if action.startswith('search['):
120
+ page_type = Page.RESULTS
121
+ search_terms = action_content
122
+ page_num = 1
123
+ elif action.startswith('click['):
124
+ if action.startswith('click[item -'):
125
+ prod_title = action_content[len("item -"):].strip()
126
+ found = False
127
+ for key in title_to_asin_map:
128
+ if prod_title == key:
129
+ asin = title_to_asin_map[key]
130
+ page_type = Page.ITEM_PAGE
131
+ visited_asins.add(asin)
132
+ found = True
133
+ break
134
+ if not found:
135
+ raise Exception("Product to click not found")
136
+
137
+ elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
138
+ page_type = Page.SUB_PAGE
139
+ sub_page_type = Page(action_content.lower())
140
+
141
+ elif action == 'click[< prev]':
142
+ if sub_page_type is not None:
143
+ page_type, sub_page_type = Page.ITEM_PAGE, None
144
+ elif prev_page_type == Page.ITEM_PAGE:
145
+ page_type = Page.RESULTS
146
+ options, clicked_options = {}, set()
147
+ elif prev_page_type == Page.RESULTS and page_num > 1:
148
+ page_type = Page.RESULTS
149
+ page_num -= 1
150
+
151
+ elif action == 'click[next >]':
152
+ page_type = Page.RESULTS
153
+ page_num += 1
154
+
155
+ elif action.lower() == 'click[back to search]':
156
+ page_type = Page.SEARCH
157
+
158
+ elif action == 'click[buy now]':
159
+ asin_url = f"https://www.amazon.com/dp/{asin}"
160
+ return_value = "Product URL: " + asin_url
161
+ if len(clicked_options) > 0:
162
+ options_str = ', '.join(list(clicked_options))
163
+ return_value += "\nSelected Options: " + options_str
164
+ return return_value
165
+
166
+ elif prev_page_type == Page.ITEM_PAGE:
167
+ found = False
168
+ for opt_name, opt_values in product_map[asin]["options"].items():
169
+ if action_content in opt_values:
170
+ options[opt_name] = action_content
171
+ page_type = Page.ITEM_PAGE
172
+ clicked_options.add(action_content)
173
+ found = True
174
+ break
175
+ if not found:
176
+ raise Exception("Unrecognized action: " + action)
177
+ else:
178
+ raise Exception("Unrecognized action:" + action)
179
+
180
+ if verbose:
181
+ print(f"Parsing {page_type.value} page...")
182
+
183
+ # URL -> Real HTML -> Dict of Info
184
+ if page_type == Page.RESULTS:
185
+ if search_terms in search_results_cache:
186
+ data = search_results_cache[search_terms]
187
+ else:
188
+ begin = time.time()
189
+ data = parse_results(search_terms, page_num)
190
+ end = time.time()
191
+ print("Parsing search results took", end-begin, "seconds")
192
+
193
+ search_results_cache[search_terms] = data
194
+ num_prods = len(data)
195
+ for d in data:
196
+ title_to_asin_map[d['Title']] = d['asin']
197
+ elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
198
+ if asin in product_map:
199
+ print("Loading cached item page for", asin)
200
+ data = product_map[asin]
201
+ else:
202
+ begin = time.time()
203
+ data = parse_item_page(asin)
204
+ end = time.time()
205
+ print("Parsing item page took", end-begin, "seconds")
206
+ product_map[asin] = data
207
+ elif page_type == Page.SEARCH:
208
+ if verbose:
209
+ print("Executing search")
210
+ obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
211
+ info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
212
+ continue
213
+ else:
214
+ raise Exception("Page of type `", page_type, "` not found")
215
+
216
+ # Dict of Info -> Fake HTML -> Text Observation
217
+ begin = time.time()
218
+ html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
219
+ obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
220
+ end = time.time()
221
+ print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")
222
+
223
+ # Dict of Info -> Valid Action State (Info)
224
+ begin = time.time()
225
+ prod_arg = product_map if page_type == Page.ITEM_PAGE else data
226
+ info = convert_dict_to_actions(page_type, prod_arg, asin, page_num, num_prods)
227
+ end = time.time()
228
+ print("Extracting available actions took", end-begin, "seconds")
229
+
230
+ if i == 99:
231
+ asin_url = f"https://www.amazon.com/dp/{asin}"
232
+ return_value = "Product URL: " + asin_url
233
+ if len(clicked_options) > 0:
234
+ options_str = ', '.join(list(clicked_options))
235
+ return_value += "\nSelected Options: " + options_str
236
+ return return_value
237
 
238
  gr.Interface(fn=run_episode,\
239
  inputs=gr.inputs.Textbox(lines=7, label="Input Text"),\
240
  outputs="text",\
241
  examples=[
242
+ "Please select a 1 pound, certified organic sea salt shaker in the flavor triple blend flakes, and price lower than 40.00 dollars",
243
  "I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars",
244
  "I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound",
245
  "I'm looking for a kids toothbrush for ages 6 to 12 that will help with teeth whitening and is easy to use",
246
+ "I need some cute heart-shaped glittery cupcake picks as a gift to bring to a baby shower",
247
  ],\
248
  title="WebShop",\
249
  article="<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>",\
predict.py DELETED
@@ -1,250 +0,0 @@
1
- import time, torch
2
- from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
3
-
4
- from webshop_lite import dict_to_fake_html
5
- from predict_help import convert_dict_to_actions, convert_html_to_text, parse_results, parse_item_page, Page
6
-
7
- # Configurations
8
- DETAILED_OUTPUT = True
9
- BART_MODEL_PATH = 'webshop/il_search_bart'
10
- BERT_MODEL_PATH = 'webshop/il-rl-choice-bert-image_1'
11
-
12
- # load IL models
13
- bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
14
- bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL_PATH)
15
-
16
- bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
17
- bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
18
- bert_model = AutoModel.from_pretrained(BERT_MODEL_PATH, trust_remote_code=True)
19
-
20
- def process_str(s):
21
- s = s.lower().replace('"', '').replace("'", "").strip()
22
- s = s.replace('[sep]', '[SEP]')
23
- return s
24
-
25
-
26
- def process_goal(state):
27
- state = state.lower().replace('"', '').replace("'", "")
28
- state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
29
- if ', and price lower than' in state:
30
- state = state.split(', and price lower than')[0]
31
- return state
32
-
33
-
34
- def data_collator(batch):
35
- state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
36
- for sample in batch:
37
- state_input_ids.append(sample['state_input_ids'])
38
- state_attention_mask.append(sample['state_attention_mask'])
39
- action_input_ids.extend(sample['action_input_ids'])
40
- action_attention_mask.extend(sample['action_attention_mask'])
41
- sizes.append(sample['sizes'])
42
- labels.append(sample['labels'])
43
- images.append(sample['images'])
44
- max_state_len = max(sum(x) for x in state_attention_mask)
45
- max_action_len = max(sum(x) for x in action_attention_mask)
46
- return {
47
- 'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
48
- 'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
49
- 'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
50
- 'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
51
- 'sizes': torch.tensor(sizes),
52
- 'images': torch.tensor(images),
53
- 'labels': torch.tensor(labels),
54
- }
55
-
56
-
57
- def bart_predict(input):
58
- input_ids = bart_tokenizer(input)['input_ids']
59
- input_ids = torch.tensor(input_ids).unsqueeze(0)
60
- output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
61
- return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]
62
-
63
-
64
- def bert_predict(obs, info, softmax=True):
65
- valid_acts = info['valid']
66
- assert valid_acts[0].startswith('click[')
67
- state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
68
- action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
69
- batch = {
70
- 'state_input_ids': state_encodings['input_ids'],
71
- 'state_attention_mask': state_encodings['attention_mask'],
72
- 'action_input_ids': action_encodings['input_ids'],
73
- 'action_attention_mask': action_encodings['attention_mask'],
74
- 'sizes': len(valid_acts),
75
- 'images': info['image_feat'].tolist(),
76
- 'labels': 0
77
- }
78
- batch = data_collator([batch])
79
- outputs = bert_model(**batch)
80
- if softmax:
81
- idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
82
- else:
83
- idx = outputs.logits[0].argmax(0).item()
84
- return valid_acts[idx]
85
-
86
-
87
- def predict(obs, info):
88
- """
89
- Given WebShop environment observation and info, predict an action.
90
- """
91
- valid_acts = info['valid']
92
- if valid_acts[0].startswith('click['):
93
- return bert_predict(obs, info)
94
- else:
95
- return "search[" + bart_predict(process_goal(obs)) + "]"
96
-
97
- def run_episode(goal, verbose=True):
98
- """
99
- Interact with amazon to find a product given input goal.
100
- Input: text goal
101
- Output: a url of found item on amazon.
102
- """
103
- obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
104
- info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
105
- product_map = {}
106
- title_to_asin_map = {}
107
- search_results_cache = {}
108
- visited_asins, clicked_options = set(), set()
109
- sub_page_type, page_type, page_num = None, None, None
110
- search_terms, prod_title, asin, num_prods, = None, None, None, None
111
- options = {}
112
-
113
- for i in range(100):
114
- # Run prediction
115
- action = predict(obs, info)
116
- if verbose:
117
- print("====\n" + action)
118
-
119
- # Previous Page Type, Action -> Next Page Type
120
- action_content = action[action.find("[")+1:action.find("]")]
121
- prev_page_type = page_type
122
- if action.startswith('search['):
123
- page_type = Page.RESULTS
124
- search_terms = action_content
125
- page_num = 1
126
- elif action.startswith('click['):
127
- if action.startswith('click[item -'):
128
- prod_title = action_content[len("item -"):].strip()
129
- found = False
130
- for key in title_to_asin_map:
131
- if prod_title == key:
132
- asin = title_to_asin_map[key]
133
- page_type = Page.ITEM_PAGE
134
- visited_asins.add(asin)
135
- found = True
136
- break
137
- if not found:
138
- raise Exception("Product to click not found")
139
-
140
- elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
141
- page_type = Page.SUB_PAGE
142
- sub_page_type = Page(action_content.lower())
143
-
144
- elif action == 'click[< prev]':
145
- if sub_page_type is not None:
146
- page_type, sub_page_type = Page.ITEM_PAGE, None
147
- elif prev_page_type == Page.ITEM_PAGE:
148
- page_type = Page.RESULTS
149
- options, clicked_options = {}, set()
150
- elif prev_page_type == Page.RESULTS and page_num > 1:
151
- page_type = Page.RESULTS
152
- page_num -= 1
153
-
154
- elif action == 'click[next >]':
155
- page_type = Page.RESULTS
156
- page_num += 1
157
-
158
- elif action.lower() == 'click[back to search]':
159
- page_type = Page.SEARCH
160
-
161
- elif action == 'click[buy now]':
162
- if DETAILED_OUTPUT:
163
- asin_url = f"https://www.amazon.com/dp/{asin}"
164
- return_value = "Product URL: " + asin_url
165
- if len(clicked_options) > 0:
166
- options_str = ', '.join(list(clicked_options))
167
- return_value += "\nSelected Options: " + options_str
168
- return return_value
169
- else:
170
- return asin
171
-
172
- elif prev_page_type == Page.ITEM_PAGE:
173
- found = False
174
- for opt_name, opt_values in product_map[asin]["options"].items():
175
- if action_content in opt_values:
176
- options[opt_name] = action_content
177
- page_type = Page.ITEM_PAGE
178
- clicked_options.add(action_content)
179
- found = True
180
- break
181
- if not found:
182
- raise Exception("Unrecognized action: " + action)
183
- else:
184
- raise Exception("Unrecognized action:" + action)
185
-
186
- if verbose:
187
- print(f"Parsing {page_type.value} page...")
188
-
189
- # URL -> Real HTML -> Dict of Info
190
- if page_type == Page.RESULTS:
191
- if search_terms in search_results_cache:
192
- data = search_results_cache[search_terms]
193
- else:
194
- begin = time.time()
195
- data = parse_results(search_terms, page_num, verbose)
196
- end = time.time()
197
- if verbose:
198
- print("Parsing search results took", end-begin, "seconds")
199
-
200
- search_results_cache[search_terms] = data
201
- num_prods = len(data)
202
- for d in data:
203
- title_to_asin_map[d['Title']] = d['asin']
204
- elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
205
- if asin in product_map:
206
- if verbose:
207
- print("Loading cached item page for", asin)
208
- data = product_map[asin]
209
- else:
210
- begin = time.time()
211
- data = parse_item_page(asin, verbose)
212
- end = time.time()
213
- if verbose:
214
- print("Parsing item page took", end-begin, "seconds")
215
- product_map[asin] = data
216
- elif page_type == Page.SEARCH:
217
- if verbose:
218
- print("Executing search")
219
- obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
220
- info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
221
- continue
222
- else:
223
- raise Exception("Page of type `", page_type, "` not found")
224
-
225
- # Dict of Info -> Fake HTML -> Text Observation
226
- begin = time.time()
227
- html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
228
- obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
229
- end = time.time()
230
- if verbose:
231
- print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")
232
-
233
- # Dict of Info -> Valid Action State (Info)
234
- begin = time.time()
235
- prod_arg = product_map if page_type == Page.ITEM_PAGE else data
236
- info = convert_dict_to_actions(page_type, prod_arg, asin, page_num, num_prods)
237
- end = time.time()
238
- if verbose:
239
- print("Extracting available actions took", end-begin, "seconds")
240
-
241
- if i == 99:
242
- if DETAILED_OUTPUT:
243
- asin_url = f"https://www.amazon.com/dp/{asin}"
244
- return_value = "Product URL: " + asin_url
245
- if len(clicked_options) > 0:
246
- options_str = ', '.join(list(clicked_options))
247
- return_value += "\nSelected Options: " + options_str
248
- return return_value
249
- else:
250
- return asin