use news-api

app/app.py (+31, -47):
@@ -4,11 +4,11 @@ from mtranslate import translate
 from prompts import PROMPT_LIST
 import random
 import time
-from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
 import psutil
 import torch
 import os
+import requests
+
 
 # st.set_page_config(page_title="Indonesian GPT-2")
 
@@ -16,6 +16,7 @@ mirror_url = "https://news-generator.ai-research.id/"
 if "MIRROR_URL" in os.environ:
     mirror_url = os.environ["MIRROR_URL"]
 hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
+news_api_auth_token = os.getenv("NEWS_API_AUTH_TOKEN", False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 MODELS = {
@@ -59,51 +60,39 @@ ___
 model_type = st.sidebar.selectbox('Model', (MODELS.keys()))
 
 
-@st.cache(suppress_st_warning=True, allow_output_mutation=True)
-def get_generator(model_name: str):
-    st.write(f"Loading the GPT2 model {model_name}, please wait...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
-    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
-    model.to(device)
-    model.resize_token_embeddings(len(tokenizer))
-    return model, tokenizer
-
-
 # Disable the st.cache for this function due to issue on newer version of streamlit
 # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
-def process(
+def process(title: str, keywords: str, text: str,
         max_length: int = 250, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
         temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
         penalty_alpha = 0.6):
     # st.write("Cache miss: process")
-    result = result[:title_index] if title_index > 0 else result
-    return result
+    url = 'https://news-api.uncool.ai/api/text_generator/v1'
+    print("news_api_auth_token:", news_api_auth_token)
+    headers = {'Authorization': 'Bearer ' + news_api_auth_token}
+    print("Requesting to news-api.uncool.ai with headers: ", headers)
+    data = {
+        "title": title,
+        "keywords": keywords,
+        "text": text,
+        "max_length": max_length,
+        "do_sample": do_sample,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "max_time": max_time,
+        "seed": seed,
+        "repetition_penalty": repetition_penalty,
+        "penalty_alpha": penalty_alpha
+    }
+    r = requests.post(url, headers=headers, data=data)
+    if r.status_code == 200:
+        result = r.json()['generated_text']
+        title_index = result.find("title: ")
+        result = result[:title_index] if title_index > 0 else result
+        return result
+    else:
+        return "Error: " + r.text
 
 
 st.title("Indonesian GPT-2 Applications")
@@ -215,10 +204,6 @@ if prompt_group_name in ["Indonesian Newspaper"]:
         max_value=2.0
    )
 
-    for group_name in MODELS:
-        if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
-            MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
-                get_generator(MODELS[group_name]["name"])
     # st.write(f"Generator: {MODELS}'")
     if st.button("Run"):
         with st.spinner(text="Getting results..."):
@@ -226,8 +211,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
             st.subheader("Result")
             time_start = time.time()
             # text_generator = MODELS[model_type]["text_generator"]
-            result = process(
-                title=session_state.title,
+            result = process(title=session_state.title,
                 keywords=session_state.keywords,
                 text=session_state.text, max_length=int(max_length),
                 temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,