|
import os |
|
# Restrict the CUDA devices visible to this process to index 1. This is set
# ahead of the NeuralTextGenerator import that follows further down the file.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
|
|
|
import re |
|
import gradio as gr |
|
from NeuralTextGenerator import BertTextGenerator |
|
|
|
|
|
|
|
# Sentiment control tokens: three numbered markers per polarity, ordered
# POSITIVE-0..2 followed by NEGATIVE-0..2.
special_tokens = [
    f'[{polarity}-{level}]'
    for polarity in ('POSITIVE', 'NEGATIVE')
    for level in range(3)
]
|
|
|
|
|
# Load the fine-tuned RoBERTa generator from its checkpoint id.
finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
finetunned_RoBERTa = BertTextGenerator(finetunned_RoBERTa_model_name)

# Register the sentiment control tokens with the tokenizer, then resize the
# model's token embeddings to match the tokenizer's new length.
finetunned_RoBERTa.tokenizer.add_special_tokens(
    dict(additional_special_tokens=special_tokens)
)
finetunned_RoBERTa.model.resize_token_embeddings(
    len(finetunned_RoBERTa.tokenizer)
)
|
|
|
|
|
# Checkpoint ids for the remaining generators offered in the UI.
finetunned_RoBERTa_Hate_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen_FT_Hate"
RoBERTa_model_name = "cardiffnlp/twitter-xlm-roberta-base"
BERT_model_name = "Twitter/twhin-bert-large"

# Instantiate one generator per checkpoint, in the order listed above.
# NOTE(review): no special tokens are registered for these three generators
# here, although sentence_builder prompts every model with bracketed
# "[<sentiment>-n]" markers -- confirm each tokenizer handles them as intended.
finetunned_RoBERTa_Hate, RoBERTa, BERT = map(
    BertTextGenerator,
    (
        finetunned_RoBERTa_Hate_model_name,
        RoBERTa_model_name,
        BERT_model_name,
    ),
)
|
|
|
def sentence_builder( |
|
selected_model, |
|
n_sentences, |
|
max_iter, |
|
temperature, |
|
top_k, |
|
sentiment, |
|
seed_text |
|
): |
|
|
|
if selected_model == "Finetuned_RoBERTa": |
|
generator = finetunned_RoBERTa |
|
elif selected_model == "Finetuned_RoBERTa_Hate": |
|
generator = finetunned_RoBERTa_Hate |
|
sentiment = 'HATE' |
|
if selected_model == "RoBERTa": |
|
generator = RoBERTa |
|
else: |
|
generator = BERT |
|
|
|
|
|
parameters = {'n_sentences': n_sentences, |
|
'batch_size': n_sentences if n_sentences < 10 else 10, |
|
'avg_len':30, |
|
'max_len':50, |
|
'std_len' : 3, |
|
'generation_method':'parallel', |
|
'sample': True, |
|
'burnin': 450, |
|
'max_iter': max_iter, |
|
'top_k': top_k, |
|
'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2] {seed_text}", |
|
'temperature': temperature, |
|
'verbose': True |
|
} |
|
sents = generator.generate(**parameters) |
|
|
|
|
|
gen_text = '' |
|
for i, s in enumerate(sents): |
|
clean_sent = re.sub(r'\[.*?\]', '', s) |
|
gen_text += f'- GENERATED TWEET #{i + 1}: {clean_sent}\n\n' |
|
|
|
return gen_text |
|
|
|
|
|
# Gradio UI. The input widgets are listed in the same order as the parameters
# of sentence_builder, which receives their values positionally.
demo = gr.Interface(
    fn=sentence_builder,
    inputs=[
        gr.Radio(
            choices=["BERT", "RoBERTa", "Finetuned_RoBERTa", "Finetuned_RoBERTa_Hate"],
            value="RoBERTa",
            label="Generator model",
        ),
        gr.Slider(
            minimum=1,
            maximum=15,
            value=5,
            step=1,
            label="Num. Tweets",
            info="Number of tweets to be generated.",
        ),
        gr.Slider(
            minimum=50,
            maximum=500,
            value=300,
            label="Max. iter",
            info="Maximum number of iterations for the generation.",
        ),
        gr.Slider(
            minimum=0,
            maximum=1.0,
            value=0.8,
            step=0.05,
            label="Temperature",
            info="Temperature parameter for the generation.",
        ),
        gr.Slider(
            minimum=1,
            maximum=200,
            value=130,
            step=1,
            label="Top k",
            info="Top k parameter for the generation.",
        ),
        gr.Radio(
            choices=["POSITIVE", "NEGATIVE"],
            value="NEGATIVE",
            label="Sentiment to generate",
        ),
        gr.Textbox(
            value='ATP Finals in Turin',
            label="Seed text",
            info="Seed text for the generation.",
        ),
    ],
    outputs="text",
)

# Start the web server for the interface defined above.
demo.launch()