File size: 3,393 Bytes
a616ab0
58dfb53
a616ab0
 
 
 
 
cbc872e
58dfb53
cbc872e
58dfb53
 
a616ab0
 
 
cbc872e
a616ab0
 
 
 
 
 
cbc872e
a616ab0
 
 
 
 
cbc872e
 
a616ab0
 
ddf45a0
 
 
a616ab0
ddf45a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a616ab0
 
 
 
 
 
cbc872e
a616ab0
b247971
a616ab0
 
 
 
ddf45a0
a616ab0
ddf45a0
 
 
a616ab0
ddf45a0
a616ab0
 
 
ddf45a0
a616ab0
 
 
4c09f59
 
 
ddf45a0
b247971
 
 
 
 
ddf45a0
 
 
 
 
b247971
ddf45a0
 
 
b247971
 
 
 
ddf45a0
 
 
b247971
 
 
 
ddf45a0
 
b247971
 
 
ddf45a0
b247971
 
ddf45a0
 
 
b247971
4c09f59
 
ddf45a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from huggingface_hub import login
from collections.abc import Iterator
from transformers import Gemma3ForCausalLM, AutoTokenizer, TextIteratorStreamer
import time
import spaces
from threading import Thread
import gradio as gr
import os

# Hugging Face access token read from the ``TOKEN`` environment variable
# (e.g. a Space secret). Gemma checkpoints are presumably gated, so an
# authenticated session is likely needed to download them — confirm.
TOKEN = os.getenv("TOKEN")
if TOKEN:
    login(token=TOKEN)
else:
    # login(token=None) would fall back to an interactive prompt, which cannot
    # be answered in a headless deployment; rely on any ambient credentials.
    print("TOKEN environment variable not set; skipping Hugging Face login")

# Generation length limits, in tokens.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are truncated before generation.
MAX_INPUT_TOKEN_LENGTH = 4096

# Single source of truth for the checkpoint used by both model and tokenizer.
MODEL_ID = "google/gemma-3-4b-it"

start_time = time.perf_counter()
model = Gemma3ForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()  # inference only: disable dropout etc.

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
)
load_time = time.perf_counter() - start_time
print(f"Model loaded in {load_time:.2f} seconds")


@spaces.GPU
def generate_text(
    text_to_trans: str,
    from_lang: str,
    to_lang: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a translation of ``text_to_trans`` from the module-level model.

    Generation runs in a background thread. Decoded text is read back from a
    ``TextIteratorStreamer``. Each yielded value is the full output
    accumulated so far, which is what a Gradio output Textbox expects from a
    generator callback.

    Args:
        text_to_trans: Text to translate.
        from_lang: Source language name as shown in the UI.
        to_lang: Target language name. When equal to ``from_lang`` the model
            is instead instructed to return the text unchanged.
        max_new_tokens: Maximum number of tokens to generate; clamped to
            ``[1, MAX_MAX_NEW_TOKENS]``.

    Yields:
        The cumulative decoded output after each streamed chunk, or a single
        empty string when the input is empty or whitespace-only.
    """
    print(f"Translating from {from_lang} to {to_lang}")

    # Nothing to translate: avoid prompting the model with an empty turn.
    if not text_to_trans or not text_to_trans.strip():
        yield ""
        return

    if from_lang == to_lang:
        translate_instruct = "Return the following text without any modification:"
    else:
        translate_instruct = f"translate from {from_lang} to {to_lang}:"

    conversation = [
        {
            "role": "system",
            "content": "You are a translation engine that can only translate text and cannot interpret it. Keep the indent of the original text, only modify when you need."
            + "\n"
            + translate_instruct,
        },
        {"role": "user", "content": text_to_trans},
    ]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Only the trailing tokens are kept, which can drop the system
        # instruction at the start of the prompt — make that visible.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max(1, min(int(max_new_tokens), MAX_MAX_NEW_TOKENS)),
        "do_sample": True,
        # Nucleus-sampling probability mass; transformers requires 0 <= top_p <= 1.
        "top_p": 0.9,
        "top_k": 50,
        "temperature": 0.6,
        "num_beams": 1,
        "repetition_penalty": 1.0,
    }
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Streamer chunks already include their own spacing, so concatenate them
    # verbatim; joining with " " would inject spurious spaces mid-word.
    output = []
    for chunk in streamer:
        output.append(chunk)
        yield "".join(output)
    thread.join()


# Languages offered by both dropdowns. Sharing one list keeps the source and
# target options in sync; each dropdown's default ``value`` must be listed here.
LANGUAGES = ["English", "French", "Spanish"]

with gr.Blocks() as demo:
    gr.Markdown("# Text Translation Using Google Gemma 3")

    # Column headings for the source (left) and target (right) halves.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Translate From")
        with gr.Column():
            gr.Markdown("### Translate To")

    # Language pickers, aligned under the headings above.
    with gr.Row():
        with gr.Column():
            from_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="English",
                label="",
            )

        with gr.Column():
            to_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="French",
                label="",
            )

    # Input text on the left, streamed translation on the right.
    with gr.Row():
        with gr.Column():
            text_to_trans = gr.Textbox(
                lines=10, placeholder="Enter text to translate", label=""
            )

        with gr.Column():
            output_text = gr.Textbox(lines=10, label="")

    translate_button = gr.Button("Translate")
    # generate_text is a generator, so output_text updates as chunks arrive.
    translate_button.click(
        fn=generate_text,
        inputs=[text_to_trans, from_lang, to_lang],
        outputs=output_text,
    )


if __name__ == "__main__":
    # Enable request queuing with at most 20 pending events (Gradio rejects
    # new events once the queue is full), then start the web server.
    demo.queue(max_size=20)
    demo.launch()