Spaces:
Runtime error
Runtime error
John Yang
commited on
Commit
·
4b9c9b6
1
Parent(s):
1e7de71
Restore working version
Browse files- app.py +237 -1
- predict.py +0 -250
app.py
CHANGED
@@ -1,13 +1,249 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
gr.Interface(fn=run_episode,\
|
5 |
inputs=gr.inputs.Textbox(lines=7, label="Input Text"),\
|
6 |
outputs="text",\
|
7 |
examples=[
|
|
|
8 |
"I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars",
|
9 |
"I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound",
|
10 |
"I'm looking for a kids toothbrush for ages 6 to 12 that will help with teeth whitening and is easy to use",
|
|
|
11 |
],\
|
12 |
title="WebShop",\
|
13 |
article="<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>",\
|
|
|
1 |
import gradio as gr
|
2 |
+
import time, torch
|
3 |
+
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
|
4 |
+
|
5 |
+
from webshop_lite import dict_to_fake_html
|
6 |
+
from predict_help import convert_dict_to_actions, convert_html_to_text, parse_results, parse_item_page, Page
|
7 |
+
|
8 |
+
# load IL models
|
9 |
+
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
10 |
+
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')
|
11 |
+
|
12 |
+
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
|
13 |
+
bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
|
14 |
+
bert_model = AutoModel.from_pretrained('webshop/il-choice-bert-image_0', trust_remote_code=True)
|
15 |
+
|
16 |
+
def process_str(s):
|
17 |
+
s = s.lower().replace('"', '').replace("'", "").strip()
|
18 |
+
s = s.replace('[sep]', '[SEP]')
|
19 |
+
return s
|
20 |
+
|
21 |
+
|
22 |
+
def process_goal(state):
|
23 |
+
state = state.lower().replace('"', '').replace("'", "")
|
24 |
+
state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
|
25 |
+
if ', and price lower than' in state:
|
26 |
+
state = state.split(', and price lower than')[0]
|
27 |
+
return state
|
28 |
+
|
29 |
+
|
30 |
+
def data_collator(batch):
|
31 |
+
state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
|
32 |
+
for sample in batch:
|
33 |
+
state_input_ids.append(sample['state_input_ids'])
|
34 |
+
state_attention_mask.append(sample['state_attention_mask'])
|
35 |
+
action_input_ids.extend(sample['action_input_ids'])
|
36 |
+
action_attention_mask.extend(sample['action_attention_mask'])
|
37 |
+
sizes.append(sample['sizes'])
|
38 |
+
labels.append(sample['labels'])
|
39 |
+
images.append(sample['images'])
|
40 |
+
max_state_len = max(sum(x) for x in state_attention_mask)
|
41 |
+
max_action_len = max(sum(x) for x in action_attention_mask)
|
42 |
+
return {
|
43 |
+
'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
|
44 |
+
'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
|
45 |
+
'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
|
46 |
+
'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
|
47 |
+
'sizes': torch.tensor(sizes),
|
48 |
+
'images': torch.tensor(images),
|
49 |
+
'labels': torch.tensor(labels),
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
def bart_predict(input):
|
54 |
+
input_ids = bart_tokenizer(input)['input_ids']
|
55 |
+
input_ids = torch.tensor(input_ids).unsqueeze(0)
|
56 |
+
output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
|
57 |
+
return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]
|
58 |
+
|
59 |
+
|
60 |
+
def bert_predict(obs, info, softmax=True):
|
61 |
+
valid_acts = info['valid']
|
62 |
+
assert valid_acts[0].startswith('click[')
|
63 |
+
state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
|
64 |
+
action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
|
65 |
+
batch = {
|
66 |
+
'state_input_ids': state_encodings['input_ids'],
|
67 |
+
'state_attention_mask': state_encodings['attention_mask'],
|
68 |
+
'action_input_ids': action_encodings['input_ids'],
|
69 |
+
'action_attention_mask': action_encodings['attention_mask'],
|
70 |
+
'sizes': len(valid_acts),
|
71 |
+
'images': info['image_feat'].tolist(),
|
72 |
+
'labels': 0
|
73 |
+
}
|
74 |
+
batch = data_collator([batch])
|
75 |
+
outputs = bert_model(**batch)
|
76 |
+
if softmax:
|
77 |
+
idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
|
78 |
+
else:
|
79 |
+
idx = outputs.logits[0].argmax(0).item()
|
80 |
+
return valid_acts[idx]
|
81 |
+
|
82 |
+
|
83 |
+
def predict(obs, info):
|
84 |
+
"""
|
85 |
+
Given WebShop environment observation and info, predict an action.
|
86 |
+
"""
|
87 |
+
valid_acts = info['valid']
|
88 |
+
if valid_acts[0].startswith('click['):
|
89 |
+
return bert_predict(obs, info)
|
90 |
+
else:
|
91 |
+
return "search[" + bart_predict(process_goal(obs)) + "]"
|
92 |
+
|
93 |
+
def run_episode(goal, verbose=True):
|
94 |
+
"""
|
95 |
+
Interact with amazon to find a product given input goal.
|
96 |
+
Input: text goal
|
97 |
+
Output: a url of found item on amazon.
|
98 |
+
"""
|
99 |
+
obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
|
100 |
+
info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
|
101 |
+
product_map = {}
|
102 |
+
title_to_asin_map = {}
|
103 |
+
search_results_cache = {}
|
104 |
+
visited_asins, clicked_options = set(), set()
|
105 |
+
sub_page_type, page_type, page_num = None, None, None
|
106 |
+
search_terms, prod_title, asin, num_prods, = None, None, None, None
|
107 |
+
options = {}
|
108 |
+
|
109 |
+
for i in range(100):
|
110 |
+
# Run prediction
|
111 |
+
action = predict(obs, info)
|
112 |
+
if verbose:
|
113 |
+
print("====")
|
114 |
+
print(action)
|
115 |
+
|
116 |
+
# Previous Page Type, Action -> Next Page Type
|
117 |
+
action_content = action[action.find("[")+1:action.find("]")]
|
118 |
+
prev_page_type = page_type
|
119 |
+
if action.startswith('search['):
|
120 |
+
page_type = Page.RESULTS
|
121 |
+
search_terms = action_content
|
122 |
+
page_num = 1
|
123 |
+
elif action.startswith('click['):
|
124 |
+
if action.startswith('click[item -'):
|
125 |
+
prod_title = action_content[len("item -"):].strip()
|
126 |
+
found = False
|
127 |
+
for key in title_to_asin_map:
|
128 |
+
if prod_title == key:
|
129 |
+
asin = title_to_asin_map[key]
|
130 |
+
page_type = Page.ITEM_PAGE
|
131 |
+
visited_asins.add(asin)
|
132 |
+
found = True
|
133 |
+
break
|
134 |
+
if not found:
|
135 |
+
raise Exception("Product to click not found")
|
136 |
+
|
137 |
+
elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
|
138 |
+
page_type = Page.SUB_PAGE
|
139 |
+
sub_page_type = Page(action_content.lower())
|
140 |
+
|
141 |
+
elif action == 'click[< prev]':
|
142 |
+
if sub_page_type is not None:
|
143 |
+
page_type, sub_page_type = Page.ITEM_PAGE, None
|
144 |
+
elif prev_page_type == Page.ITEM_PAGE:
|
145 |
+
page_type = Page.RESULTS
|
146 |
+
options, clicked_options = {}, set()
|
147 |
+
elif prev_page_type == Page.RESULTS and page_num > 1:
|
148 |
+
page_type = Page.RESULTS
|
149 |
+
page_num -= 1
|
150 |
+
|
151 |
+
elif action == 'click[next >]':
|
152 |
+
page_type = Page.RESULTS
|
153 |
+
page_num += 1
|
154 |
+
|
155 |
+
elif action.lower() == 'click[back to search]':
|
156 |
+
page_type = Page.SEARCH
|
157 |
+
|
158 |
+
elif action == 'click[buy now]':
|
159 |
+
asin_url = f"https://www.amazon.com/dp/{asin}"
|
160 |
+
return_value = "Product URL: " + asin_url
|
161 |
+
if len(clicked_options) > 0:
|
162 |
+
options_str = ', '.join(list(clicked_options))
|
163 |
+
return_value += "\nSelected Options: " + options_str
|
164 |
+
return return_value
|
165 |
+
|
166 |
+
elif prev_page_type == Page.ITEM_PAGE:
|
167 |
+
found = False
|
168 |
+
for opt_name, opt_values in product_map[asin]["options"].items():
|
169 |
+
if action_content in opt_values:
|
170 |
+
options[opt_name] = action_content
|
171 |
+
page_type = Page.ITEM_PAGE
|
172 |
+
clicked_options.add(action_content)
|
173 |
+
found = True
|
174 |
+
break
|
175 |
+
if not found:
|
176 |
+
raise Exception("Unrecognized action: " + action)
|
177 |
+
else:
|
178 |
+
raise Exception("Unrecognized action:" + action)
|
179 |
+
|
180 |
+
if verbose:
|
181 |
+
print(f"Parsing {page_type.value} page...")
|
182 |
+
|
183 |
+
# URL -> Real HTML -> Dict of Info
|
184 |
+
if page_type == Page.RESULTS:
|
185 |
+
if search_terms in search_results_cache:
|
186 |
+
data = search_results_cache[search_terms]
|
187 |
+
else:
|
188 |
+
begin = time.time()
|
189 |
+
data = parse_results(search_terms, page_num)
|
190 |
+
end = time.time()
|
191 |
+
print("Parsing search results took", end-begin, "seconds")
|
192 |
+
|
193 |
+
search_results_cache[search_terms] = data
|
194 |
+
num_prods = len(data)
|
195 |
+
for d in data:
|
196 |
+
title_to_asin_map[d['Title']] = d['asin']
|
197 |
+
elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
|
198 |
+
if asin in product_map:
|
199 |
+
print("Loading cached item page for", asin)
|
200 |
+
data = product_map[asin]
|
201 |
+
else:
|
202 |
+
begin = time.time()
|
203 |
+
data = parse_item_page(asin)
|
204 |
+
end = time.time()
|
205 |
+
print("Parsing item page took", end-begin, "seconds")
|
206 |
+
product_map[asin] = data
|
207 |
+
elif page_type == Page.SEARCH:
|
208 |
+
if verbose:
|
209 |
+
print("Executing search")
|
210 |
+
obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
|
211 |
+
info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
|
212 |
+
continue
|
213 |
+
else:
|
214 |
+
raise Exception("Page of type `", page_type, "` not found")
|
215 |
+
|
216 |
+
# Dict of Info -> Fake HTML -> Text Observation
|
217 |
+
begin = time.time()
|
218 |
+
html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
|
219 |
+
obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
|
220 |
+
end = time.time()
|
221 |
+
print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")
|
222 |
+
|
223 |
+
# Dict of Info -> Valid Action State (Info)
|
224 |
+
begin = time.time()
|
225 |
+
prod_arg = product_map if page_type == Page.ITEM_PAGE else data
|
226 |
+
info = convert_dict_to_actions(page_type, prod_arg, asin, page_num, num_prods)
|
227 |
+
end = time.time()
|
228 |
+
print("Extracting available actions took", end-begin, "seconds")
|
229 |
+
|
230 |
+
if i == 99:
|
231 |
+
asin_url = f"https://www.amazon.com/dp/{asin}"
|
232 |
+
return_value = "Product URL: " + asin_url
|
233 |
+
if len(clicked_options) > 0:
|
234 |
+
options_str = ', '.join(list(clicked_options))
|
235 |
+
return_value += "\nSelected Options: " + options_str
|
236 |
+
return return_value
|
237 |
|
238 |
gr.Interface(fn=run_episode,\
|
239 |
inputs=gr.inputs.Textbox(lines=7, label="Input Text"),\
|
240 |
outputs="text",\
|
241 |
examples=[
|
242 |
+
"Please select a 1 pound, certified organic sea salt shaker in the flavor triple blend flakes, and price lower than 40.00 dollars",
|
243 |
"I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars",
|
244 |
"I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound",
|
245 |
"I'm looking for a kids toothbrush for ages 6 to 12 that will help with teeth whitening and is easy to use",
|
246 |
+
"I need some cute heart-shaped glittery cupcake picks as a gift to bring to a baby shower",
|
247 |
],\
|
248 |
title="WebShop",\
|
249 |
article="<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>",\
|
predict.py
DELETED
@@ -1,250 +0,0 @@
|
|
1 |
-
import time, torch
|
2 |
-
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
|
3 |
-
|
4 |
-
from webshop_lite import dict_to_fake_html
|
5 |
-
from predict_help import convert_dict_to_actions, convert_html_to_text, parse_results, parse_item_page, Page
|
6 |
-
|
7 |
-
# Configurations
|
8 |
-
DETAILED_OUTPUT = True
|
9 |
-
BART_MODEL_PATH = 'webshop/il_search_bart'
|
10 |
-
BERT_MODEL_PATH = 'webshop/il-rl-choice-bert-image_1'
|
11 |
-
|
12 |
-
# load IL models
|
13 |
-
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
14 |
-
bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL_PATH)
|
15 |
-
|
16 |
-
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
|
17 |
-
bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
|
18 |
-
bert_model = AutoModel.from_pretrained(BERT_MODEL_PATH, trust_remote_code=True)
|
19 |
-
|
20 |
-
def process_str(s):
|
21 |
-
s = s.lower().replace('"', '').replace("'", "").strip()
|
22 |
-
s = s.replace('[sep]', '[SEP]')
|
23 |
-
return s
|
24 |
-
|
25 |
-
|
26 |
-
def process_goal(state):
|
27 |
-
state = state.lower().replace('"', '').replace("'", "")
|
28 |
-
state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
|
29 |
-
if ', and price lower than' in state:
|
30 |
-
state = state.split(', and price lower than')[0]
|
31 |
-
return state
|
32 |
-
|
33 |
-
|
34 |
-
def data_collator(batch):
|
35 |
-
state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
|
36 |
-
for sample in batch:
|
37 |
-
state_input_ids.append(sample['state_input_ids'])
|
38 |
-
state_attention_mask.append(sample['state_attention_mask'])
|
39 |
-
action_input_ids.extend(sample['action_input_ids'])
|
40 |
-
action_attention_mask.extend(sample['action_attention_mask'])
|
41 |
-
sizes.append(sample['sizes'])
|
42 |
-
labels.append(sample['labels'])
|
43 |
-
images.append(sample['images'])
|
44 |
-
max_state_len = max(sum(x) for x in state_attention_mask)
|
45 |
-
max_action_len = max(sum(x) for x in action_attention_mask)
|
46 |
-
return {
|
47 |
-
'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
|
48 |
-
'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
|
49 |
-
'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
|
50 |
-
'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
|
51 |
-
'sizes': torch.tensor(sizes),
|
52 |
-
'images': torch.tensor(images),
|
53 |
-
'labels': torch.tensor(labels),
|
54 |
-
}
|
55 |
-
|
56 |
-
|
57 |
-
def bart_predict(input):
|
58 |
-
input_ids = bart_tokenizer(input)['input_ids']
|
59 |
-
input_ids = torch.tensor(input_ids).unsqueeze(0)
|
60 |
-
output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
|
61 |
-
return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]
|
62 |
-
|
63 |
-
|
64 |
-
def bert_predict(obs, info, softmax=True):
|
65 |
-
valid_acts = info['valid']
|
66 |
-
assert valid_acts[0].startswith('click[')
|
67 |
-
state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
|
68 |
-
action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
|
69 |
-
batch = {
|
70 |
-
'state_input_ids': state_encodings['input_ids'],
|
71 |
-
'state_attention_mask': state_encodings['attention_mask'],
|
72 |
-
'action_input_ids': action_encodings['input_ids'],
|
73 |
-
'action_attention_mask': action_encodings['attention_mask'],
|
74 |
-
'sizes': len(valid_acts),
|
75 |
-
'images': info['image_feat'].tolist(),
|
76 |
-
'labels': 0
|
77 |
-
}
|
78 |
-
batch = data_collator([batch])
|
79 |
-
outputs = bert_model(**batch)
|
80 |
-
if softmax:
|
81 |
-
idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
|
82 |
-
else:
|
83 |
-
idx = outputs.logits[0].argmax(0).item()
|
84 |
-
return valid_acts[idx]
|
85 |
-
|
86 |
-
|
87 |
-
def predict(obs, info):
|
88 |
-
"""
|
89 |
-
Given WebShop environment observation and info, predict an action.
|
90 |
-
"""
|
91 |
-
valid_acts = info['valid']
|
92 |
-
if valid_acts[0].startswith('click['):
|
93 |
-
return bert_predict(obs, info)
|
94 |
-
else:
|
95 |
-
return "search[" + bart_predict(process_goal(obs)) + "]"
|
96 |
-
|
97 |
-
def run_episode(goal, verbose=True):
|
98 |
-
"""
|
99 |
-
Interact with amazon to find a product given input goal.
|
100 |
-
Input: text goal
|
101 |
-
Output: a url of found item on amazon.
|
102 |
-
"""
|
103 |
-
obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
|
104 |
-
info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
|
105 |
-
product_map = {}
|
106 |
-
title_to_asin_map = {}
|
107 |
-
search_results_cache = {}
|
108 |
-
visited_asins, clicked_options = set(), set()
|
109 |
-
sub_page_type, page_type, page_num = None, None, None
|
110 |
-
search_terms, prod_title, asin, num_prods, = None, None, None, None
|
111 |
-
options = {}
|
112 |
-
|
113 |
-
for i in range(100):
|
114 |
-
# Run prediction
|
115 |
-
action = predict(obs, info)
|
116 |
-
if verbose:
|
117 |
-
print("====\n" + action)
|
118 |
-
|
119 |
-
# Previous Page Type, Action -> Next Page Type
|
120 |
-
action_content = action[action.find("[")+1:action.find("]")]
|
121 |
-
prev_page_type = page_type
|
122 |
-
if action.startswith('search['):
|
123 |
-
page_type = Page.RESULTS
|
124 |
-
search_terms = action_content
|
125 |
-
page_num = 1
|
126 |
-
elif action.startswith('click['):
|
127 |
-
if action.startswith('click[item -'):
|
128 |
-
prod_title = action_content[len("item -"):].strip()
|
129 |
-
found = False
|
130 |
-
for key in title_to_asin_map:
|
131 |
-
if prod_title == key:
|
132 |
-
asin = title_to_asin_map[key]
|
133 |
-
page_type = Page.ITEM_PAGE
|
134 |
-
visited_asins.add(asin)
|
135 |
-
found = True
|
136 |
-
break
|
137 |
-
if not found:
|
138 |
-
raise Exception("Product to click not found")
|
139 |
-
|
140 |
-
elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
|
141 |
-
page_type = Page.SUB_PAGE
|
142 |
-
sub_page_type = Page(action_content.lower())
|
143 |
-
|
144 |
-
elif action == 'click[< prev]':
|
145 |
-
if sub_page_type is not None:
|
146 |
-
page_type, sub_page_type = Page.ITEM_PAGE, None
|
147 |
-
elif prev_page_type == Page.ITEM_PAGE:
|
148 |
-
page_type = Page.RESULTS
|
149 |
-
options, clicked_options = {}, set()
|
150 |
-
elif prev_page_type == Page.RESULTS and page_num > 1:
|
151 |
-
page_type = Page.RESULTS
|
152 |
-
page_num -= 1
|
153 |
-
|
154 |
-
elif action == 'click[next >]':
|
155 |
-
page_type = Page.RESULTS
|
156 |
-
page_num += 1
|
157 |
-
|
158 |
-
elif action.lower() == 'click[back to search]':
|
159 |
-
page_type = Page.SEARCH
|
160 |
-
|
161 |
-
elif action == 'click[buy now]':
|
162 |
-
if DETAILED_OUTPUT:
|
163 |
-
asin_url = f"https://www.amazon.com/dp/{asin}"
|
164 |
-
return_value = "Product URL: " + asin_url
|
165 |
-
if len(clicked_options) > 0:
|
166 |
-
options_str = ', '.join(list(clicked_options))
|
167 |
-
return_value += "\nSelected Options: " + options_str
|
168 |
-
return return_value
|
169 |
-
else:
|
170 |
-
return asin
|
171 |
-
|
172 |
-
elif prev_page_type == Page.ITEM_PAGE:
|
173 |
-
found = False
|
174 |
-
for opt_name, opt_values in product_map[asin]["options"].items():
|
175 |
-
if action_content in opt_values:
|
176 |
-
options[opt_name] = action_content
|
177 |
-
page_type = Page.ITEM_PAGE
|
178 |
-
clicked_options.add(action_content)
|
179 |
-
found = True
|
180 |
-
break
|
181 |
-
if not found:
|
182 |
-
raise Exception("Unrecognized action: " + action)
|
183 |
-
else:
|
184 |
-
raise Exception("Unrecognized action:" + action)
|
185 |
-
|
186 |
-
if verbose:
|
187 |
-
print(f"Parsing {page_type.value} page...")
|
188 |
-
|
189 |
-
# URL -> Real HTML -> Dict of Info
|
190 |
-
if page_type == Page.RESULTS:
|
191 |
-
if search_terms in search_results_cache:
|
192 |
-
data = search_results_cache[search_terms]
|
193 |
-
else:
|
194 |
-
begin = time.time()
|
195 |
-
data = parse_results(search_terms, page_num, verbose)
|
196 |
-
end = time.time()
|
197 |
-
if verbose:
|
198 |
-
print("Parsing search results took", end-begin, "seconds")
|
199 |
-
|
200 |
-
search_results_cache[search_terms] = data
|
201 |
-
num_prods = len(data)
|
202 |
-
for d in data:
|
203 |
-
title_to_asin_map[d['Title']] = d['asin']
|
204 |
-
elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
|
205 |
-
if asin in product_map:
|
206 |
-
if verbose:
|
207 |
-
print("Loading cached item page for", asin)
|
208 |
-
data = product_map[asin]
|
209 |
-
else:
|
210 |
-
begin = time.time()
|
211 |
-
data = parse_item_page(asin, verbose)
|
212 |
-
end = time.time()
|
213 |
-
if verbose:
|
214 |
-
print("Parsing item page took", end-begin, "seconds")
|
215 |
-
product_map[asin] = data
|
216 |
-
elif page_type == Page.SEARCH:
|
217 |
-
if verbose:
|
218 |
-
print("Executing search")
|
219 |
-
obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button]"
|
220 |
-
info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
|
221 |
-
continue
|
222 |
-
else:
|
223 |
-
raise Exception("Page of type `", page_type, "` not found")
|
224 |
-
|
225 |
-
# Dict of Info -> Fake HTML -> Text Observation
|
226 |
-
begin = time.time()
|
227 |
-
html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
|
228 |
-
obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
|
229 |
-
end = time.time()
|
230 |
-
if verbose:
|
231 |
-
print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")
|
232 |
-
|
233 |
-
# Dict of Info -> Valid Action State (Info)
|
234 |
-
begin = time.time()
|
235 |
-
prod_arg = product_map if page_type == Page.ITEM_PAGE else data
|
236 |
-
info = convert_dict_to_actions(page_type, prod_arg, asin, page_num, num_prods)
|
237 |
-
end = time.time()
|
238 |
-
if verbose:
|
239 |
-
print("Extracting available actions took", end-begin, "seconds")
|
240 |
-
|
241 |
-
if i == 99:
|
242 |
-
if DETAILED_OUTPUT:
|
243 |
-
asin_url = f"https://www.amazon.com/dp/{asin}"
|
244 |
-
return_value = "Product URL: " + asin_url
|
245 |
-
if len(clicked_options) > 0:
|
246 |
-
options_str = ', '.join(list(clicked_options))
|
247 |
-
return_value += "\nSelected Options: " + options_str
|
248 |
-
return return_value
|
249 |
-
else:
|
250 |
-
return asin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|