File size: 3,561 Bytes
998b0a8
 
a411157
998b0a8
 
b6e525e
998b0a8
192eae5
dfc584d
192eae5
3d5b038
f2d1f01
b6e525e
dfc584d
b6e525e
dfc584d
 
 
 
 
 
 
 
 
ef4866e
dfc584d
 
 
 
 
 
 
 
ef4866e
3d5b038
 
 
dfc584d
 
 
3d5b038
dfc584d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f63e352
dfc584d
 
 
4664330
dfc584d
4664330
 
3d5b038
9213095
998b0a8
 
 
dfc584d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
998b0a8
dfc584d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4664330
dfc584d
 
 
 
 
 
 
 
 
998b0a8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from huggingface_hub import InferenceClient
import os

"""
Inference code copied from the Colab notebook.
"""

from transformers import pipeline

# Checkpoint identifier handed to the `from_pretrained` calls further down.
# The tokenizer and model are loaded from it once, at module level, so that
# they are not reloaded for every request.
model_path = "Mat17892/t5small_enfr_opus"

# translator = pipeline("translation_xx_to_yy", model=model_path)

# def respond(
#     message: str,
#     history: list[tuple[str, str]],
#     system_message: str,
#     max_tokens: int,
#     temperature: float,
#     top_p: float,
# ):
#     message = "translate English to French:" + message

#     response = translator(message)[0]
#     yield response['translation_text']

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
import threading

# Module-level tokenizer and seq2seq model, both loaded from `model_path`
# when the file is imported and then reused by every call to `respond`.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

def respond(
    message: str,
    system_message: str,
    max_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
):
    """Stream a sampled translation of ``message`` as progressively longer text.

    The prompt is ``system_message + " " + message``. Generation runs in a
    background thread that feeds a ``TextIteratorStreamer``; every chunk the
    streamer emits is appended to the text produced so far.

    Args:
        message: Source text to translate.
        system_message: Task prefix placed in front of ``message``.
        max_tokens: Upper bound passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Yields:
        The accumulated output text after each streamed chunk. A blank
        ``message`` yields a single empty string without running the model.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here after the stream has been closed.
    """
    # Nothing to translate: clear the output instead of sampling from the
    # bare task prefix.
    if not message or not message.strip():
        yield ""
        return

    # Preprocess the input message
    input_text = system_message + " " + message
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": input_ids,
        # The value may arrive as a float (e.g. from a UI slider or an API
        # client); coerce it because it is a token count.
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "streamer": streamer,
    }
    failures = []

    def _generate():
        # Run generation; on failure, close the stream so the consumer loop
        # below terminates instead of waiting forever on the queue.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    # Generate in a separate thread to avoid blocking
    generation_thread = threading.Thread(target=_generate)
    generation_thread.start()

    # Stream the output progressively
    generated_text = ""
    for chunk in streamer:
        generated_text += chunk
        yield generated_text

    # The stream is closed, so the worker is finishing; reap it and surface
    # any error it recorded to the caller.
    generation_thread.join()
    if failures:
        raise failures[0]


"""
For information on how to customize the interface, peruse the Gradio docs (this link covers
ChatInterface; the UI below is built with gr.Blocks): https://www.gradio.app/docs/chatinterface
"""

# Translation UI: a source box and a read-only result box in one row, a
# trigger button, and a collapsed panel of generation settings.
with gr.Blocks() as demo:
    gr.Markdown("# Google Translate-like Interface")

    # Two columns in a single row: editable input, non-interactive output.
    with gr.Row():
        with gr.Column():
            source_textbox = gr.Textbox(
                label="Source Text (English)",
                placeholder="Enter text in English...",
                lines=5,
            )
        with gr.Column():
            translated_textbox = gr.Textbox(
                label="Translated Text (French)",
                placeholder="Translation will appear here...",
                lines=5,
                interactive=False,
            )

    translate_button = gr.Button("Translate")

    # Settings forwarded to `respond`; the accordion starts closed.
    with gr.Accordion("Advanced Settings", open=False):
        system_message_input = gr.Textbox(
            label="System message",
            value="translate English to French:",
        )
        max_tokens_slider = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
        )
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        )
        top_p_slider = gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )

    # Clicking the button calls `respond` with the components below, in the
    # order of its parameters, and writes the result to the translated box.
    translate_button.click(
        fn=respond,
        inputs=[
            source_textbox,
            system_message_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
        ],
        outputs=translated_textbox,
    )

# Launch the Gradio app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()