Spaces:
Runtime error
Runtime error
remove print out hf_
Browse files- app/app.py +3 -16
app/app.py
CHANGED
@@ -16,7 +16,6 @@ mirror_url = "https://news-generator.ai-research.id/"
|
|
16 |
if "MIRROR_URL" in os.environ:
|
17 |
mirror_url = os.environ["MIRROR_URL"]
|
18 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
19 |
-
st.write(f"Using Hugging Face auth token: {hf_auth_token[:10]}...")
|
20 |
|
21 |
MODELS = {
|
22 |
"Indonesian Newspaper - Indonesian GPT-2 Medium": {
|
@@ -80,24 +79,12 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
|
|
80 |
min_penalty = 1.05
|
81 |
max_penalty = 1.5
|
82 |
repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
|
83 |
-
# print("title:", title)
|
84 |
-
# print("keywords:", keywords)
|
85 |
prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
|
86 |
-
# print("prompt: ", prompt)
|
87 |
|
88 |
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
89 |
# device = torch.device("cuda")
|
90 |
# generated = generated.to(device)
|
91 |
|
92 |
-
print("do_sample:", do_sample)
|
93 |
-
print("penalty_alpha:", penalty_alpha)
|
94 |
-
print("max_length:", max_length)
|
95 |
-
print("top_k:", top_k)
|
96 |
-
print("top_p:", top_p)
|
97 |
-
print("temperature:", temperature)
|
98 |
-
print("max_time:", max_time)
|
99 |
-
print("repetition_penalty:", repetition_penalty)
|
100 |
-
|
101 |
text_generator.eval()
|
102 |
sample_outputs = text_generator.generate(generated,
|
103 |
penalty_alpha=penalty_alpha,
|
@@ -111,16 +98,17 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
|
|
111 |
num_return_sequences=1
|
112 |
)
|
113 |
result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
|
114 |
-
# print(f"result: {result}")
|
115 |
prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
|
116 |
result = result[prefix_length:]
|
|
|
|
|
117 |
return result
|
118 |
|
119 |
|
120 |
st.title("Indonesian GPT-2 Applications")
|
121 |
prompt_group_name = MODELS[model_type]["group"]
|
122 |
st.header(prompt_group_name)
|
123 |
-
description = f"This is a
|
124 |
st.markdown(description)
|
125 |
model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co/{MODELS[model_type]['name']})"
|
126 |
st.markdown(model_name)
|
@@ -246,7 +234,6 @@ if prompt_group_name in ["Indonesian Newspaper"]:
|
|
246 |
time_end = time.time()
|
247 |
time_diff = time_end - time_start
|
248 |
# result = result[0]["generated_text"]
|
249 |
-
result = result[:result.find("title:")]
|
250 |
st.write(result.replace("\n", " \n"))
|
251 |
st.text("Translation")
|
252 |
translation = translate(result, "en", "id")
|
|
|
16 |
if "MIRROR_URL" in os.environ:
|
17 |
mirror_url = os.environ["MIRROR_URL"]
|
18 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
|
|
19 |
|
20 |
MODELS = {
|
21 |
"Indonesian Newspaper - Indonesian GPT-2 Medium": {
|
|
|
79 |
min_penalty = 1.05
|
80 |
max_penalty = 1.5
|
81 |
repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
|
|
|
|
|
82 |
prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
|
|
|
83 |
|
84 |
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
85 |
# device = torch.device("cuda")
|
86 |
# generated = generated.to(device)
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
text_generator.eval()
|
89 |
sample_outputs = text_generator.generate(generated,
|
90 |
penalty_alpha=penalty_alpha,
|
|
|
98 |
num_return_sequences=1
|
99 |
)
|
100 |
result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
|
|
|
101 |
prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
|
102 |
result = result[prefix_length:]
|
103 |
+
title_index = result.find("title: ")
|
104 |
+
result = result[:title_index] if title_index > 0 else result
|
105 |
return result
|
106 |
|
107 |
|
108 |
st.title("Indonesian GPT-2 Applications")
|
109 |
prompt_group_name = MODELS[model_type]["group"]
|
110 |
st.header(prompt_group_name)
|
111 |
+
description = f"This is a news generator using Indonesian GPT-2 Medium. We finetuned the pre-trained model with the Indonesian online newspaper dataset."
|
112 |
st.markdown(description)
|
113 |
model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co/{MODELS[model_type]['name']})"
|
114 |
st.markdown(model_name)
|
|
|
234 |
time_end = time.time()
|
235 |
time_diff = time_end - time_start
|
236 |
# result = result[0]["generated_text"]
|
|
|
237 |
st.write(result.replace("\n", " \n"))
|
238 |
st.text("Translation")
|
239 |
translation = translate(result, "en", "id")
|