Commit
·
24c8264
1
Parent(s):
7708da3
adding special tokens
Browse files
app.py
CHANGED
@@ -21,8 +21,11 @@ tokenizer = en_model.tokenizer
|
|
21 |
model = en_model.model
|
22 |
device = model.device
|
23 |
|
|
|
|
|
|
|
24 |
def classify(sentiment):
|
25 |
-
parameters = {'n_sentences':
|
26 |
'batch_size': 2,
|
27 |
'avg_len':30,
|
28 |
'max_len':50,
|
@@ -30,7 +33,7 @@ def classify(sentiment):
|
|
30 |
'generation_method':'parallel',
|
31 |
'sample': True,
|
32 |
'burnin': 450,
|
33 |
-
'max_iter':
|
34 |
'top_k': 100,
|
35 |
'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2]",
|
36 |
# 'verbose': True
|
@@ -44,7 +47,7 @@ demo = gr.Blocks()
|
|
44 |
|
45 |
with demo:
|
46 |
gr.Markdown()
|
47 |
-
inputs = gr.Dropdown(
|
48 |
output = gr.Textbox(label="Generated tweet")
|
49 |
b1 = gr.Button("Generate")
|
50 |
b1.click(classify, inputs=inputs, outputs=output)
|
|
|
21 |
model = en_model.model
|
22 |
device = model.device
|
23 |
|
24 |
+
en_model.tokenizer.add_special_tokens({'additional_special_tokens': ['[POSITIVE-0]', '[POSITIVE-1]', '[POSITIVE-2]','[NEGATIVE-0]', '[NEGATIVE-1]', '[NEGATIVE-2]']})
|
25 |
+
en_model.model.resize_token_embeddings(len(en_model.tokenizer))
|
26 |
+
|
27 |
def classify(sentiment):
|
28 |
+
parameters = {'n_sentences': 5,
|
29 |
'batch_size': 2,
|
30 |
'avg_len':30,
|
31 |
'max_len':50,
|
|
|
33 |
'generation_method':'parallel',
|
34 |
'sample': True,
|
35 |
'burnin': 450,
|
36 |
+
'max_iter': 150,
|
37 |
'top_k': 100,
|
38 |
'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2]",
|
39 |
# 'verbose': True
|
|
|
47 |
|
48 |
with demo:
|
49 |
gr.Markdown()
|
50 |
+
inputs = gr.Dropdown(["POSITIVE", "NEGATIVE"], label="Sentiment to generate")
|
51 |
output = gr.Textbox(label="Generated tweet")
|
52 |
b1 = gr.Button("Generate")
|
53 |
b1.click(classify, inputs=inputs, outputs=output)
|