JuanJoseMV committed
Commit 8e7d1f2 · 1 Parent(s): 46677b4

Adding pre-trained bert

Files changed (1)
  1. app.py +21 -17
app.py CHANGED
@@ -2,14 +2,22 @@ import gradio as gr
 from NeuralTextGenerator import BertTextGenerator
 
 # Load models
-model_name = "cardiffnlp/twitter-xlm-roberta-base"
-en_model = BertTextGenerator(model_name, tokenizer=model_name)
 
+## BERT
+BERT_model_name = "Twitter/twhin-bert-large"
+BERT = BertTextGenerator(BERT_model_name, tokenizer=BERT_model_name)
+
+## RoBERTa
+RoBERTa_model_name = "cardiffnlp/twitter-xlm-roberta-base"
+RoBERTa = BertTextGenerator(RoBERTa_model_name, tokenizer=RoBERTa_model_name)
+
+## Finetuned BERT
 finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
-finetunned_BERT_en_model = BertTextGenerator(finetunned_BERT_model_name, tokenizer='bert-base-uncased')
+finetunned_BERT = BertTextGenerator(finetunned_BERT_model_name, tokenizer='bert-base-uncased')
 
+## Finetuned RoBERTa
 finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
-finetunned_RoBERTa_en_model = BertTextGenerator(finetunned_RoBERTa_model_name, tokenizer=finetunned_RoBERTa_model_name)
+finetunned_RoBERTa = BertTextGenerator(finetunned_RoBERTa_model_name, tokenizer=finetunned_RoBERTa_model_name)
 
 special_tokens = [
     '[POSITIVE-0]',
@@ -20,23 +28,19 @@ special_tokens = [
     '[NEGATIVE-2]'
 ]
 
-# en_model.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
-# en_model.model.resize_token_embeddings(len(en_model.tokenizer))
-
-finetunned_BERT_en_model.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
-finetunned_BERT_en_model.model.resize_token_embeddings(len(en_model.tokenizer))
-
-# finetunned_RoBERTa_en_model.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
-# finetunned_RoBERTa_en_model.model.resize_token_embeddings(len(en_model.tokenizer))
+finetunned_BERT.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
+finetunned_BERT.model.resize_token_embeddings(len(finetunned_BERT.tokenizer))
 
 def sentence_builder(selected_model, n_sentences, max_iter, sentiment, seed_text):
 
-    if selected_model == "Finetuned_RoBERTA":
-        generator = finetunned_RoBERTa_en_model
+    if selected_model == "Finetuned_RoBERTa":
+        generator = finetunned_RoBERTa
     elif selected_model == "Finetuned_BERT":
-        generator = finetunned_BERT_en_model
+        generator = finetunned_BERT
+    elif selected_model == "RoBERTa":
+        generator = RoBERTa
     else:
-        generator = en_model
+        generator = BERT
 
     parameters = {'n_sentences': n_sentences,
                   'batch_size': 2,
@@ -63,7 +67,7 @@ def sentence_builder(selected_model, n_sentences, max_iter, sentiment, seed_text
 demo = gr.Interface(
     sentence_builder,
     [
-        gr.Radio(["Pre-trained", "Finetuned_RoBERTA", "Finetunned_BERT"], value="Pre-trained", label="Sentiment to generate"),
+        gr.Radio(["BERT", "RoBERTa", "Finetuned_RoBERTa", "Finetuned_BERT"], value="BERT", label="Generator model"),
         gr.Slider(1, 15, value=2, label="Num. Tweets", step=1, info="Number of tweets to be generated."),
         gr.Slider(50, 500, value=100, label="Max. iter", info="Maximum number of iterations for the generation."),
         gr.Radio(["POSITIVE", "NEGATIVE"], value="POSITIVE", label="Sentiment to generate"),