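# Hugging Face Space (ZeroGPU): streaming text translation with
# google/gemma-3-4b-it, served through a Gradio interface.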
import os
import time
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, Gemma3ForCausalLM, TextIteratorStreamer
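# Authenticate with the Hugging Face Hub; the token comes from the Space's
# "TOKEN" secret and is needed to download the gated Gemma 3 weights.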
TOKEN = os.getenv("TOKEN")
login(token=TOKEN)
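# Token budgets: caps on generated tokens and on the prompt length.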
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4096
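# Load the model and tokenizer once at startup, timing how long it takes.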
start_time = time.time()
model = Gemma3ForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-3-4b-it",
)
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds")
@spaces.GPU
def generate_text(
    text_to_trans: str,
    from_lang: str,
    to_lang: str,
) -> Iterator[str]:
    print(f"Translating from {from_lang} to {to_lang}")
    translate_instruct = f"translate from {from_lang} to {to_lang}:"
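    # When source and target languages match, instruct the model to echo
    # the text back unchanged instead of translating.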
    if from_lang == to_lang:
        translate_instruct = "Return the following text without any modification:"
    conversation = [
        {
            "role": "system",
            "content": "You are a translation engine that can only translate text and cannot interpret it. Keep the indentation of the original text; only modify it when necessary."
            + "\n"
            + translate_instruct,
        },
        {"role": "user", "content": text_to_trans},
    ]
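    # Render the conversation with the tokenizer's chat template and tokenize it.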
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
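    # Truncate from the left so the most recent tokens fit the input budget.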
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
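    # The streamer yields decoded text chunks as generate() produces them;
    # skip_prompt drops the echoed input, and the timeout guards against a
    # stalled generation.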
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        temperature=0.6,
        num_beams=1,
        repetition_penalty=1.0,
    )
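    # Run generation on a background thread so partial output can be yielded
    # from the streamer while the model is still decoding.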
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    output = []
    for text in streamer:
        output.append(text)
        yield "".join(output)
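# Two-column layout: language dropdowns on top, input/output text boxes below,
# and a Translate button that streams results into the output box.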
with gr.Blocks() as demo:
    gr.Markdown("# Text Translation Using Google Gemma 3")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Translate From")
        with gr.Column():
            gr.Markdown("### Translate To")
    with gr.Row():
        with gr.Column():
            from_lang = gr.Dropdown(
                choices=["English", "French", "Spanish"],
                value="English",
                label="",
            )
        with gr.Column():
            to_lang = gr.Dropdown(
                choices=["English", "French", "Spanish"],
                value="French",
                label="",
            )
    with gr.Row():
        with gr.Column():
            text_to_trans = gr.Textbox(
                lines=10, placeholder="Enter text to translate", label=""
            )
        with gr.Column():
            output_text = gr.Textbox(lines=10, label="")
    translate_button = gr.Button("Translate")
    translate_button.click(
        generate_text, [text_to_trans, from_lang, to_lang], output_text
    )
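# Queue incoming requests (at most 20 waiting) before launching the app.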
if __name__ == "__main__":
    demo.queue(max_size=20).launch()