cahya committed
Commit 2117b5e · 1 Parent(s): 1afd82a

use news-api

Files changed (1)
  1. app/app.py +31 -47
app/app.py CHANGED
@@ -4,11 +4,11 @@ from mtranslate import translate
 from prompts import PROMPT_LIST
 import random
 import time
-from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
 import psutil
 import torch
 import os
-from abstract_dataset import AbstractDataset
+import requests
+
 
 # st.set_page_config(page_title="Indonesian GPT-2")
 
@@ -16,6 +16,7 @@ mirror_url = "https://news-generator.ai-research.id/"
 if "MIRROR_URL" in os.environ:
     mirror_url = os.environ["MIRROR_URL"]
 hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
+news_api_auth_token = os.getenv("NEWS_API_AUTH_TOKEN", False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 MODELS = {
@@ -59,51 +60,39 @@ ___
 model_type = st.sidebar.selectbox('Model', (MODELS.keys()))
 
 
-@st.cache(suppress_st_warning=True, allow_output_mutation=True)
-def get_generator(model_name: str):
-    st.write(f"Loading the GPT2 model {model_name}, please wait...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
-    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
-    model.to(device)
-    model.resize_token_embeddings(len(tokenizer))
-    return model, tokenizer
-
-
 # Disable the st.cache for this function due to issue on newer version of streamlit
 # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
-def process(text_generator, tokenizer, title: str, keywords: str, text: str,
+def process(title: str, keywords: str, text: str,
             max_length: int = 250, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
             temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
            penalty_alpha = 0.6):
     # st.write("Cache miss: process")
-    set_seed(seed)
-    if repetition_penalty == 0.0:
-        min_penalty = 1.05
-        max_penalty = 1.5
-        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
-    prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
-
-    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
-    generated = generated.to(device)
-
-    text_generator.eval()
-    sample_outputs = text_generator.generate(generated,
-                                             penalty_alpha=penalty_alpha,
-                                             do_sample=do_sample,
-                                             min_length=200,
-                                             max_length=max_length,
-                                             top_k=top_k,
-                                             top_p=top_p,
-                                             temperature=temperature,
-                                             repetition_penalty=repetition_penalty,
-                                             num_return_sequences=1
-                                             )
-    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
-    prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
-    result = result[prefix_length:]
-    title_index = result.find("title: ")
-    result = result[:title_index] if title_index > 0 else result
-    return result
+    url = 'https://news-api.uncool.ai/api/text_generator/v1'
+    print("news_api_auth_token:", news_api_auth_token)
+    headers = {'Authorization': 'Bearer ' + news_api_auth_token}
+    print("Requesting to news-api.uncool.ai with headers: ", headers)
+    data = {
+        "title": title,
+        "keywords": keywords,
+        "text": text,
+        "max_length": max_length,
+        "do_sample": do_sample,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "max_time": max_time,
+        "seed": seed,
+        "repetition_penalty": repetition_penalty,
+        "penalty_alpha": penalty_alpha
+    }
+    r = requests.post(url, headers=headers, data=data)
+    if r.status_code == 200:
+        result = r.json()['generated_text']
+        title_index = result.find("title: ")
+        result = result[:title_index] if title_index > 0 else result
+        return result
+    else:
+        return "Error: " + r.text
 
 
 st.title("Indonesian GPT-2 Applications")
@@ -215,10 +204,6 @@ if prompt_group_name in ["Indonesian Newspaper"]:
         max_value=2.0
     )
 
-    for group_name in MODELS:
-        if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
-            MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
-                get_generator(MODELS[group_name]["name"])
     # st.write(f"Generator: {MODELS}'")
     if st.button("Run"):
         with st.spinner(text="Getting results..."):
@@ -226,8 +211,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
             st.subheader("Result")
             time_start = time.time()
             # text_generator = MODELS[model_type]["text_generator"]
-            result = process(MODELS[model_type]["text_generator"], MODELS[model_type]["tokenizer"],
-                             title=session_state.title,
+            result = process(title=session_state.title,
                              keywords=session_state.keywords,
                              text=session_state.text, max_length=int(max_length),
                              temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
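
After this change the app no longer loads GPT-2 locally: process() forwards the generation parameters to the news-api service and returns its generated_text field. A minimal standalone sketch of the same request outside Streamlit, assuming only what the diff shows (the endpoint URL, the Bearer token from NEWS_API_AUTH_TOKEN, the form-encoded payload fields, and the generated_text response key); the sample title, keywords, and seed text are made-up placeholders:

import os
import requests

API_URL = "https://news-api.uncool.ai/api/text_generator/v1"  # endpoint used in app/app.py
token = os.environ["NEWS_API_AUTH_TOKEN"]  # Bearer token the API expects

payload = {
    "title": "Pemilu 2024",           # hypothetical example values
    "keywords": "presiden, partai",
    "text": "Jakarta -",              # seed text the generator continues
    "max_length": 250,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 1.0,
    "max_time": 120.0,
    "seed": 42,
    "repetition_penalty": 1.0,
    "penalty_alpha": 0.6,
}

# Like the app, send the payload form-encoded (requests' data=), not as JSON.
r = requests.post(API_URL, headers={"Authorization": f"Bearer {token}"}, data=payload)
r.raise_for_status()  # surface HTTP errors instead of returning an error string
print(r.json()["generated_text"])

Note that data= submits application/x-www-form-urlencoded, so booleans and numbers arrive as strings ("True", "250"); the service presumably parses them, since the app passes them the same way.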