Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
|
4 |
+
|
5 |
+
# Load the imitation-learning (IL) models once, at import time.
# The BART pair is what `bart_predict` uses: the tokenizer comes from the
# stock `facebook/bart-large` checkpoint, the weights from the WebShop one.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')

# TODO: make sure it could be uploaded from hub
# While this flag is False the BERT pair below is never created, so
# `bert_tokenizer` and `bert_model` stay undefined at module level.
LOAD_CHOICE_MODEL = False
if LOAD_CHOICE_MODEL:
    # Extra markers registered as special tokens on the BERT tokenizer.
    button_markers = ['[button]', '[button_]', '[clicked button]', '[clicked button_]']
    bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
    bert_tokenizer.add_tokens(button_markers, special_tokens=True)
    bert_model = AutoModel.from_pretrained('webshop/il_choice_bert')
|
13 |
+
|
14 |
+
|
15 |
+
def process_str(s):
    """Normalise a piece of text before it is tokenized.

    The text is lower-cased, every single and double quote character is
    removed, surrounding whitespace is stripped, and any ``[sep]`` marker
    (lower-cased by the first step) is restored to ``[SEP]``.
    """
    # One translation pass deletes both quote characters.
    drop_quotes = str.maketrans('', '', '"' + "'")
    cleaned = s.lower().translate(drop_quotes).strip()
    return cleaned.replace('[sep]', '[SEP]')
|
19 |
+
|
20 |
+
|
21 |
+
def process_goal(state):
    """Reduce an observation string to the bare goal text.

    The observation is lower-cased, then quote characters, the
    ``amazon shopping game\\ninstruction:`` header and the
    ``\\n[button] search [button_]`` footer are removed, in that order.
    The result is stripped, and everything from the first
    ``, and price lower than`` onwards is dropped.
    """
    goal = state.lower()
    # Order matters: quotes are removed before the header/footer are matched.
    removals = (
        '"',
        "'",
        'amazon shopping game\ninstruction:',
        '\n[button] search [button_]',
    )
    for fragment in removals:
        goal = goal.replace(fragment, '')
    goal = goal.strip()
    # partition() returns the whole string as the head when the marker is absent.
    head, _, _ = goal.partition(', and price lower than')
    return head
|
27 |
+
|
28 |
+
|
29 |
+
def data_collator(batch):
    """Merge a list of sample dicts into one dict of tensors.

    Each sample supplies ``state_input_ids``, ``state_attention_mask``,
    ``action_input_ids``, ``action_attention_mask``, ``sizes``, ``labels``
    and ``images``. State fields are stacked one row per sample; action
    fields are concatenated across samples, one row per action. Both id and
    mask tensors are cut along dim 1 to the largest attention-mask sum found
    in the batch.

    NOTE(review): the cut drops padding only if masks are 0/1 with the ones
    first (right padding) -- confirm against the tokenizer settings.
    """
    def per_sample(key):
        # One entry per sample.
        return [sample[key] for sample in batch]

    def per_action(key):
        # Per-sample lists of rows flattened into a single list of rows.
        return [row for sample in batch for row in sample[key]]

    state_ids = per_sample('state_input_ids')
    state_mask = per_sample('state_attention_mask')
    action_ids = per_action('action_input_ids')
    action_mask = per_action('action_attention_mask')

    state_width = max(map(sum, state_mask))
    action_width = max(map(sum, action_mask))

    return {
        'state_input_ids': torch.tensor(state_ids)[:, :state_width],
        'state_attention_mask': torch.tensor(state_mask)[:, :state_width],
        'action_input_ids': torch.tensor(action_ids)[:, :action_width],
        'action_attention_mask': torch.tensor(action_mask)[:, :action_width],
        'sizes': torch.tensor(per_sample('sizes')),
        'images': torch.tensor(per_sample('images')),
        'labels': torch.tensor(per_sample('labels')),
    }
|
50 |
+
|
51 |
+
|
52 |
+
def bart_predict(input):
    """Generate text from ``input`` with the module-level BART model.

    ``input`` is tokenized with ``bart_tokenizer`` and wrapped into a batch
    of one. ``bart_model.generate`` is asked for five sequences using five
    beams and at most 512 tokens. All returned sequences are decoded with
    special tokens skipped, and only the first decoded string is returned.
    """
    token_ids = bart_tokenizer(input)['input_ids']
    # generate() expects a batch dimension, hence the unsqueeze.
    batch_of_one = torch.tensor(token_ids).unsqueeze(0)
    candidates = bart_model.generate(
        batch_of_one,
        max_length=512,
        num_return_sequences=5,
        num_beams=5,
    )
    decoded = bart_tokenizer.batch_decode(
        candidates.tolist(),
        skip_special_tokens=True,
    )
    return decoded[0]
|
57 |
+
|
58 |
+
|
59 |
+
def bert_predict(obs, info, softmax=True):
    """Pick one of the valid click actions with the module-level BERT model.

    ``info['valid']`` is the list of candidate actions and
    ``info['image_feat']`` is converted with ``.tolist()`` and passed along
    as the image input. The observation and every action are cleaned with
    ``process_str`` and tokenized to a fixed length of 512. The single
    sample is batched by ``data_collator`` and fed to ``bert_model``.

    With ``softmax=True`` the action is sampled from the softmax of the
    first row of logits; otherwise the arg-max action is returned.

    Asserts that the first valid action starts with ``click[``.
    """
    actions = info['valid']
    assert actions[0].startswith('click[')

    # Same tokenizer settings for the state and for the actions.
    encode_opts = dict(max_length=512, truncation=True, padding='max_length')
    state_enc = bert_tokenizer(process_str(obs), **encode_opts)
    action_enc = bert_tokenizer([process_str(act) for act in actions], **encode_opts)

    sample = {
        'state_input_ids': state_enc['input_ids'],
        'state_attention_mask': state_enc['attention_mask'],
        'action_input_ids': action_enc['input_ids'],
        'action_attention_mask': action_enc['attention_mask'],
        'sizes': len(actions),
        'images': info['image_feat'].tolist(),
        'labels': 0,  # placeholder value; the label is not read in this function
    }
    model_inputs = data_collator([sample])
    outputs = bert_model(**model_inputs)

    scores = outputs.logits[0]
    if not softmax:
        return actions[scores.argmax(0).item()]
    probabilities = torch.nn.functional.softmax(scores, dim=0)
    choice = torch.multinomial(probabilities, 1)[0].item()
    return actions[choice]
|
80 |
+
|
81 |
+
|
82 |
+
def predict(obs, info):
    """
    Given WebShop environment observation and info, predict an action.

    When the first entry of ``info['valid']`` starts with ``click[`` the
    choice is delegated to ``bert_predict``; otherwise the observation is
    reduced with ``process_goal`` and passed to ``bart_predict``.
    """
    first_action = info['valid'][0]
    if not first_action.startswith('click['):
        return bart_predict(process_goal(obs))
    return bert_predict(obs, info)
|
91 |
+
|
92 |
+
|
93 |
+
def run_episode(goal):
    """
    Handle one text goal submitted through the UI.

    Intended behaviour: interact with amazon to find a product for the goal
    and return a url of the found item.
    Current behaviour: the goal text is passed to ``bart_predict`` and its
    result is returned unchanged.
    """
    # TODO: implement run_episode
    generated = bart_predict(goal)
    return generated
|
100 |
+
|
101 |
+
|
102 |
+
# Build and start the web UI: one multi-line text box in, plain text out.
# `gr.inputs.Textbox` is the legacy component namespace (deprecated in
# Gradio 3.x, removed in 4.x), so the top-level `gr.Textbox` is preferred
# and the legacy class is used only when the installed Gradio lacks it.
_textbox_cls = getattr(gr, "Textbox", None) or gr.inputs.Textbox
demo = gr.Interface(
    fn=run_episode,
    inputs=_textbox_cls(lines=7, label="Input Text"),
    outputs="text",
)
demo.launch(inline=False)
|