webshop committed on
Commit
1bdaecc
·
1 Parent(s): 851fd13

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
4
+
5
# Load IL models.
# The BART tokenizer comes from the stock 'facebook/bart-large' checkpoint and
# the generation model from the 'webshop/il_search_bart' checkpoint.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')

# NOTE: this branch is deliberately disabled, so `bert_tokenizer` and
# `bert_model` are NOT defined at module level until it is re-enabled.
if False:  # TODO: make sure it could be uploaded from hub
    bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
    # Register the button markers as special tokens so they are kept whole.
    bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
    bert_model = AutoModel.from_pretrained('webshop/il_choice_bert')
13
+
14
+
15
def process_str(s):
    """Normalize a string before tokenization.

    Lower-cases ``s``, removes every single- and double-quote character,
    strips leading/trailing whitespace, and finally restores the upper-case
    ``[SEP]`` marker that lower-casing turned into ``[sep]``.

    Args:
        s: Input text.

    Returns:
        The normalized string.
    """
    s = s.lower().replace('"', '').replace("'", "").strip()
    return s.replace('[sep]', '[SEP]')
19
+
20
+
21
def process_goal(state):
    """Extract the bare goal text from an initial observation string.

    The text is lower-cased and stripped of quote characters, the
    ``'amazon shopping game\\ninstruction:'`` header and the
    ``'\\n[button] search [button_]'`` trailer are removed, surrounding
    whitespace is stripped, and anything from ``', and price lower than'``
    onwards is dropped.

    Args:
        state: Observation text.

    Returns:
        The cleaned goal string.
    """
    state = state.lower().replace('"', '').replace("'", "")
    state = (
        state
        .replace('amazon shopping game\ninstruction:', '')
        .replace('\n[button] search [button_]', '')
        .strip()
    )
    # partition() returns the whole string as the head when the separator is
    # absent, so no membership test is needed before cutting the price clause.
    goal, _, _ = state.partition(', and price lower than')
    return goal
27
+
28
+
29
def data_collator(batch):
    """Collate a list of per-sample dicts into a dict of tensors.

    Each sample must provide the keys ``state_input_ids``,
    ``state_attention_mask``, ``action_input_ids``,
    ``action_attention_mask``, ``sizes``, ``labels`` and ``images``.
    State fields are stacked one row per sample; action fields are
    concatenated across samples (each sample contributes a list of rows).

    The id/mask tensors are cut to the largest attention-mask sum found in
    the batch. This assumes 0/1 masks with padding on the right - TODO
    confirm against the tokenizer configuration.

    Args:
        batch: Non-empty list of sample dicts (an empty list makes ``max``
            raise ``ValueError``).

    Returns:
        Dict of ``torch`` tensors keyed like the input samples.
    """
    state_input_ids, state_attention_mask = [], []
    action_input_ids, action_attention_mask = [], []
    sizes, labels, images = [], [], []
    for sample in batch:
        state_input_ids.append(sample['state_input_ids'])
        state_attention_mask.append(sample['state_attention_mask'])
        # Actions are flattened over the batch; 'sizes' records how many
        # action rows each sample contributed.
        action_input_ids.extend(sample['action_input_ids'])
        action_attention_mask.extend(sample['action_attention_mask'])
        sizes.append(sample['sizes'])
        labels.append(sample['labels'])
        images.append(sample['images'])

    # Longest attended length per group; used to trim trailing columns.
    max_state_len = max(sum(mask) for mask in state_attention_mask)
    max_action_len = max(sum(mask) for mask in action_attention_mask)

    return {
        'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
        'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
        'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
        'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
        'sizes': torch.tensor(sizes),
        'images': torch.tensor(images),
        'labels': torch.tensor(labels),
    }
50
+
51
+
52
def bart_predict(input):
    """Generate text from ``input`` with the module-level BART model.

    Args:
        input: Text to encode with ``bart_tokenizer``. (The name shadows the
            builtin but is kept so existing keyword callers keep working.)

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    input_ids = bart_tokenizer(input)['input_ids']
    # The model expects a batch dimension: (1, sequence_length).
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    # Beam search with 5 beams; all 5 sequences are returned but only the
    # first one is used below (presumably the best-scoring beam - see the
    # transformers generate() documentation).
    output = bart_model.generate(
        input_ids,
        max_length=512,
        num_return_sequences=5,
        num_beams=5,
    )
    decoded = bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)
    return decoded[0]
57
+
58
+
59
def bert_predict(obs, info, softmax=True):
    """Choose one of the valid click actions with the BERT choice model.

    NOTE(review): in this file ``bert_tokenizer`` and ``bert_model`` are only
    assigned inside an ``if False:`` block, so calling this function raises
    ``NameError`` until that block is enabled.

    Args:
        obs: Text observation of the current page.
        info: Dict providing ``'valid'`` (list of action strings whose first
            entry must start with ``'click['``) and ``'image_feat'`` (an
            object exposing ``.tolist()``; presumably an image-feature
            tensor/array - TODO confirm).
        softmax: If true, sample the action from the softmax distribution
            over the logits; otherwise pick the arg-max action.

    Returns:
        The chosen entry of ``info['valid']``.
    """
    valid_acts = info['valid']
    assert valid_acts[0].startswith('click[')

    state_encodings = bert_tokenizer(
        process_str(obs), max_length=512, truncation=True, padding='max_length')
    action_encodings = bert_tokenizer(
        [process_str(act) for act in valid_acts],
        max_length=512, truncation=True, padding='max_length')

    sample = {
        'state_input_ids': state_encodings['input_ids'],
        'state_attention_mask': state_encodings['attention_mask'],
        'action_input_ids': action_encodings['input_ids'],
        'action_attention_mask': action_encodings['attention_mask'],
        'sizes': len(valid_acts),
        'images': info['image_feat'].tolist(),
        'labels': 0,  # placeholder value; no real label is available here
    }
    batch = data_collator([sample])

    # Pure inference: disable autograd so no graph is built or retained.
    with torch.no_grad():
        outputs = bert_model(**batch)

    logits = outputs.logits[0]
    if softmax:
        probs = torch.nn.functional.softmax(logits, dim=0)
        idx = torch.multinomial(probs, 1)[0].item()
    else:
        idx = logits.argmax(0).item()
    return valid_acts[idx]
80
+
81
+
82
def predict(obs, info):
    """
    Given WebShop environment observation and info, predict an action.

    When the first valid action starts with ``'click['`` the choice is
    delegated to ``bert_predict``; otherwise the goal is extracted from the
    observation with ``process_goal`` and passed to ``bart_predict``.

    Args:
        obs: Text observation.
        info: Dict with a non-empty ``'valid'`` list of action strings
            (an empty list raises ``IndexError``).

    Returns:
        The value returned by the selected predictor.
    """
    valid_acts = info['valid']
    if valid_acts[0].startswith('click['):
        return bert_predict(obs, info)
    return bart_predict(process_goal(obs))
91
+
92
+
93
def run_episode(goal):
    """
    Interact with amazon to find a product given input goal.

    Intended contract:
        Input: text goal
        Output: a url of found item on amazon.

    Current behaviour: the episode loop is not implemented yet; the goal is
    passed straight to ``bart_predict`` and its generated text is returned.
    """
    # TODO: implement run_episode
    return bart_predict(goal)
100
+
101
+
102
# Prefer the top-level ``gr.Textbox`` component; the ``gr.inputs`` namespace
# is deprecated in newer Gradio releases. Fall back to the legacy class only
# when the top-level one is unavailable, so older installs keep working.
_textbox_cls = getattr(gr, "Textbox", None) or gr.inputs.Textbox

# Build the demo UI (multi-line text in, plain text out) and start the server.
gr.Interface(
    fn=run_episode,
    inputs=_textbox_cls(lines=7, label="Input Text"),
    outputs="text",
).launch(inline=False)