Sshubam commited on
Commit
0a59b92
Β·
verified Β·
1 Parent(s): 9b8148d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -145
app.py CHANGED
@@ -1,50 +1,20 @@
1
  import os
2
  import torch
3
  import spaces
4
- from collections.abc import Iterator
5
- from threading import Thread
6
  import gradio as gr
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
 
9
  MAX_MAX_NEW_TOKENS = 4096
10
- DEFAULT_MAX_NEW_TOKENS = 2048
11
  MAX_INPUT_TOKEN_LENGTH = 4096
12
-
13
  HF_TOKEN = os.environ['HF_TOKEN']
14
 
15
- DESCRIPTION = """\
16
- ## 🌏 IndicTrans3-beta πŸš€: Multilingual Translation for 22 Indic Languages
17
-
18
- IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across **22 Indic languages** with high accuracy. It supports **document-level machine translation (MT)** and is built to match the performance of other leading SOTA models.
19
-
20
- πŸ“’ **Training data will be released soon!**
21
-
22
- ### πŸ”Ή Features
23
- βœ… Supports **22 Indic languages**
24
- βœ… Enables **document-level translation**
25
- βœ… Achieves **SOTA performance** in Indic MT
26
- βœ… Optimized for **real-world applications**
27
-
28
- ### πŸš€ Try It Out!
29
- 1️⃣ Enter text in any supported language
30
- 2️⃣ Select the target language
31
- 3️⃣ Click **Translate** and get high-quality results!
32
-
33
- Built for **linguistic diversity and accessibility**, IndicTrans3 is a major step forward in **Indic language AI**.
34
-
35
- πŸ’‘ **Source:** AI4Bharat | Powered by Hugging Face
36
- """
37
-
38
- # if not torch.cuda.is_available():
39
- # DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>"
40
-
41
-
42
- # if torch.cuda.is_available():
43
  model_id = "ai4bharat/IndicTrans3-beta"
44
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", offload_folder="offload", token=HF_TOKEN)
45
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
46
 
47
-
48
  LANGUAGES = {
49
  "Hindi": "hin_Deva",
50
  "Bengali": "ben_Beng",
@@ -69,59 +39,33 @@ LANGUAGES = {
69
  "Bodo": "brx_Deva"
70
  }
71
 
72
- @spaces.GPU
73
- def generate_for_examples(
74
- tgt_lang: str,
75
- message: str,
76
- max_new_tokens: int = 1024,
77
- temperature: float = 0.6,
78
- top_p: float = 0.9,
79
- top_k: int = 50,
80
- repetition_penalty: float = 1.2,
81
- ) -> str:
82
- conversation = []
83
- conversation.append({"role": "user", "content": f"Translate the following text to {tgt_lang}: {message}"})
84
-
85
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
86
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
87
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
88
- input_ids = input_ids.to(model.device)
89
-
90
- outputs = model.generate(
91
- input_ids=input_ids,
92
- max_new_tokens=max_new_tokens,
93
- do_sample=True,
94
- top_p=top_p,
95
- top_k=top_k,
96
- temperature=temperature,
97
- num_beams=1,
98
- repetition_penalty=repetition_penalty,
99
- )
100
-
101
- return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
102
-
103
 
104
  @spaces.GPU
105
- def generate(
106
- tgt_lang: str,
107
  message: str,
 
 
108
  max_new_tokens: int = 1024,
109
  temperature: float = 0.6,
110
  top_p: float = 0.9,
111
  top_k: int = 50,
112
  repetition_penalty: float = 1.2,
113
  ) -> Iterator[str]:
114
-
115
  conversation = []
116
- conversation.append({"role": "user", "content": f"Translate the following text to {tgt_lang}: {message}"})
 
 
 
117
 
118
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
119
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
120
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
121
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
122
  input_ids = input_ids.to(model.device)
123
 
124
- streamer = TextIteratorStreamer(tokenizer, timeout=180.0, skip_prompt=True, skip_special_tokens=True)
125
  generate_kwargs = dict(
126
  {"input_ids": input_ids},
127
  streamer=streamer,
@@ -154,84 +98,84 @@ def store_feedback(rating, feedback_text):
154
  return "Thank you for your feedback!"
155
 
156
  css = """
157
- #col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
158
- #header {text-align: left;}
159
- .message { font-size: 1.2em; }
160
- #feedback-section { margin-top: 30px; border-top: 1px solid #ddd; padding-top: 20px; }
161
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
164
- gr.Markdown(DESCRIPTION, elem_id="header")
165
- gr.Markdown("Translate text between multiple Indic languages using the latest IndicTrans3 model from AI4Bharat. This model is trained on the --- dataset and supports translation to 22 Indic languages. Setting a state-of-the-art benchmark on multiple translation tasks, IndicTrans3 is a powerful model that can handle complex translation tasks with ease.", elem_id="description")
166
-
167
- with gr.Column(elem_id="col-container"):
168
- with gr.Row():
169
- with gr.Column():
170
-
171
- text_input = gr.Textbox(
172
- placeholder="Enter text to translate...",
173
- label="Input text",
174
- lines=10,
175
- max_lines=100,
176
- elem_id="input-text"
177
- )
178
-
179
- with gr.Column():
180
- tgt_lang = gr.Dropdown(
181
- list(LANGUAGES.keys()),
182
- value="Hindi",
183
- label="Translate To",
184
- elem_id="translate-to"
185
- )
186
-
187
- text_output = gr.Textbox(
188
- label="",
189
- lines=10,
190
- max_lines=100,
191
- elem_id="output-text"
192
- )
193
 
194
- btn_submit = gr.Button("Translate")
195
- btn_submit.click(
196
- fn=generate,
197
- inputs=[
198
- tgt_lang,
199
- text_input,
200
- gr.Number(value=4096, visible=False),
201
- gr.Number(value=0.1, visible=False),
202
- gr.Number(value=0.9, visible=False),
203
- gr.Number(value=50, visible=False),
204
- gr.Number(value=1.0, visible=False)
205
- ],
206
- outputs=text_output
207
  )
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  gr.Examples(
210
  examples=[
211
- ["Telugu", "Hello, how are you today? I hope you're doing well."],
212
- ["Punjabi", "Hello, how are you today? I hope you're doing well."],
213
- ["Hindi", "Hello, how are you today? I hope you're doing well."],
214
- ["Marathi", "Hello, how are you today? I hope you're doing well."],
215
- ["Malayalam", "Hello, how are you today? I hope you're doing well."]
216
  ],
217
- inputs=[
218
- tgt_lang,
219
- text_input,
220
- gr.Number(value=4096, visible=False),
221
- gr.Number(value=0.1, visible=False),
222
- gr.Number(value=0.9, visible=False),
223
- gr.Number(value=50, visible=False),
224
- gr.Number(value=1.0, visible=False)
225
- ],
226
- outputs=text_output,
227
- fn=generate_for_examples,
228
- cache_examples=True,
229
- examples_per_page=5
230
  )
231
 
232
- with gr.Column(elem_id="feedback-section"):
 
233
  gr.Markdown("## Rate Translation & Provide Feedback πŸ“")
234
- gr.Markdown("Help us improve the translation quality by providing your feedback and rating.")
235
  with gr.Row():
236
  rating = gr.Radio(
237
  ["1", "2", "3", "4", "5"],
@@ -246,11 +190,93 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
246
 
247
  feedback_submit = gr.Button("Submit Feedback")
248
  feedback_result = gr.Textbox(label="", visible=False)
249
-
250
- feedback_submit.click(
251
- fn=store_feedback,
252
- inputs=[rating, feedback_text],
253
- outputs=feedback_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- demo.launch()
 
 
1
  import os
2
  import torch
3
  import spaces
 
 
4
  import gradio as gr
5
+ from threading import Thread
6
+ from collections.abc import Iterator
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
 
9
  MAX_MAX_NEW_TOKENS = 4096
 
10
  MAX_INPUT_TOKEN_LENGTH = 4096
11
+ DEFAULT_MAX_NEW_TOKENS = 2048
12
  HF_TOKEN = os.environ['HF_TOKEN']
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model_id = "ai4bharat/IndicTrans3-beta"
15
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
16
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
17
 
 
18
  LANGUAGES = {
19
  "Hindi": "hin_Deva",
20
  "Bengali": "ben_Beng",
 
39
  "Bodo": "brx_Deva"
40
  }
41
 
42
+ def format_message_for_translation(message, target_lang):
43
+ return f"Translate the following text to {target_lang}: {message}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @spaces.GPU
46
+ def translate_message(
 
47
  message: str,
48
+ chat_history: list[dict],
49
+ target_language: str = "Hindi",
50
  max_new_tokens: int = 1024,
51
  temperature: float = 0.6,
52
  top_p: float = 0.9,
53
  top_k: int = 50,
54
  repetition_penalty: float = 1.2,
55
  ) -> Iterator[str]:
 
56
  conversation = []
57
+
58
+ translation_request = format_message_for_translation(message, target_language)
59
+ print(f"Translation request: {translation_request}")
60
+ conversation.append({"role": "user", "content": translation_request})
61
 
62
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
63
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
64
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
65
+ gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
66
  input_ids = input_ids.to(model.device)
67
 
68
+ streamer = TextIteratorStreamer(tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True)
69
  generate_kwargs = dict(
70
  {"input_ids": input_ids},
71
  streamer=streamer,
 
98
  return "Thank you for your feedback!"
99
 
100
  css = """
101
+ body {
102
+ background-color: #f7f7f7;
103
+ }
104
+ .feedback-section {
105
+ margin-top: 30px;
106
+ border-top: 1px solid #ddd;
107
+ padding-top: 20px;
108
+ }
109
+ .container {
110
+ max-width: 90%;
111
+ margin: 0 auto;
112
+ }
113
+ .language-selector {
114
+ margin-bottom: 20px;
115
+ padding: 10px;
116
+ background-color: #ffffff;
117
+ border-radius: 8px;
118
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1);
119
+ }
120
+ .advanced-options {
121
+ margin-top: 20px;
122
+ }
123
+ """
124
 
125
+ DESCRIPTION = """\
126
+ IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across <b>22 Indic languages</b> with high accuracy. It supports <b>document-level machine translation (MT)</b> and is built to match the performance of other leading SOTA models. <br>
127
+ πŸ“’ <b>Training data will be released soon!</b>
128
+ <h3>πŸ”Ή Features</h3>
129
+ βœ… Supports <b>22 Indic languages</b>
130
+ βœ… Enables <b>document-level translation</b>
131
+ βœ… Achieves <b>SOTA performance</b> in Indic MT
132
+ βœ… Optimized for <b>real-world applications</b>
133
+ <h3>πŸš€ Try It Out!</h3>
134
+ 1️⃣ Enter text in any supported language
135
+ 2️⃣ Select the target language
136
+ 3️⃣ Click <b>Translate</b> and get high-quality results!
137
+ Built for <b>linguistic diversity and accessibility</b>, IndicTrans3 is a major step forward in <b>Indic language AI</b>.
138
+ πŸ’‘ <b>Source:</b> AI4Bharat | Powered by Hugging Face
139
+ """
140
+
141
+ with gr.Blocks(css=css) as demo:
142
+ with gr.Column(elem_classes="container"):
143
+ gr.Markdown("# 🌏 IndicTrans3-beta πŸš€: Multilingual Translation for 22 Indic Languages </center>")
144
+ gr.Markdown(DESCRIPTION)
 
 
 
 
 
 
 
 
 
 
145
 
146
+ target_language = gr.Dropdown(
147
+ list(LANGUAGES.keys()),
148
+ value="Hindi",
149
+ label="Which language would you like to translate to?",
150
+ elem_id="language-dropdown"
 
 
 
 
 
 
 
 
151
  )
152
+
153
+ chatbot = gr.Chatbot(height=400, elem_id="chatbot")
154
+
155
+ with gr.Row():
156
+ msg = gr.Textbox(
157
+ placeholder="Enter text to translate...",
158
+ show_label=False,
159
+ container=False,
160
+ scale=9
161
+ )
162
+ submit_btn = gr.Button("Translate", scale=1)
163
 
164
  gr.Examples(
165
  examples=[
166
+ "The Taj Mahal stands majestically along the banks of river Yamuna, a timeless symbol of eternal love.",
167
+ "Kumbh Mela is the world's largest gathering of people, where millions of pilgrims bathe in sacred rivers for spiritual purification.",
168
+ "India's classical dance forms like Bharatanatyam, Kathak, and Odissi beautifully blend rhythm, expression, and storytelling.",
169
+ "Ayurveda, the ancient Indian medical system, focuses on holistic wellness through natural herbs and balanced living.",
170
+ "During Diwali, homes across India are decorated with oil lamps, colorful rangoli patterns, and twinkling lights to celebrate the victory of light over darkness."
171
  ],
172
+ inputs=msg
 
 
 
 
 
 
 
 
 
 
 
 
173
  )
174
 
175
+
176
+ with gr.Accordion("Provide Feedback", open=True):
177
  gr.Markdown("## Rate Translation & Provide Feedback πŸ“")
178
+ gr.Markdown("Help us improve the translation quality by providing your feedback.")
179
  with gr.Row():
180
  rating = gr.Radio(
181
  ["1", "2", "3", "4", "5"],
 
190
 
191
  feedback_submit = gr.Button("Submit Feedback")
192
  feedback_result = gr.Textbox(label="", visible=False)
193
+
194
+ with gr.Accordion("Advanced Options", open=False, elem_classes="advanced-options"):
195
+ max_new_tokens = gr.Slider(
196
+ label="Max new tokens",
197
+ minimum=1,
198
+ maximum=MAX_MAX_NEW_TOKENS,
199
+ step=1,
200
+ value=DEFAULT_MAX_NEW_TOKENS,
201
+ )
202
+ temperature = gr.Slider(
203
+ label="Temperature",
204
+ minimum=0.1,
205
+ maximum=1.0,
206
+ step=0.1,
207
+ value=0.1,
208
+ )
209
+ top_p = gr.Slider(
210
+ label="Top-p (nucleus sampling)",
211
+ minimum=0.05,
212
+ maximum=1.0,
213
+ step=0.05,
214
+ value=0.9,
215
+ )
216
+ top_k = gr.Slider(
217
+ label="Top-k",
218
+ minimum=1,
219
+ maximum=100,
220
+ step=1,
221
+ value=50,
222
+ )
223
+ repetition_penalty = gr.Slider(
224
+ label="Repetition penalty",
225
+ minimum=1.0,
226
+ maximum=2.0,
227
+ step=0.05,
228
+ value=1.0,
229
  )
230
+
231
+ chat_state = gr.State([])
232
+
233
+ def user(user_message, history, target_lang):
234
+ return "", history + [[user_message, None]]
235
+
236
+ def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty):
237
+ user_message = history[-1][0]
238
+ history[-1][1] = ""
239
+
240
+ for chunk in translate_message(
241
+ user_message,
242
+ history[:-1],
243
+ target_lang,
244
+ max_tokens,
245
+ temp,
246
+ top_p_val,
247
+ top_k_val,
248
+ rep_penalty
249
+ ):
250
+ history[-1][1] = chunk
251
+ yield history
252
+
253
+ msg.submit(
254
+ user,
255
+ [msg, chatbot, target_language],
256
+ [msg, chatbot],
257
+ queue=False
258
+ ).then(
259
+ bot,
260
+ [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
261
+ chatbot
262
+ )
263
+
264
+ submit_btn.click(
265
+ user,
266
+ [msg, chatbot, target_language],
267
+ [msg, chatbot],
268
+ queue=False
269
+ ).then(
270
+ bot,
271
+ [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
272
+ chatbot
273
+ )
274
+
275
+ feedback_submit.click(
276
+ fn=store_feedback,
277
+ inputs=[rating, feedback_text],
278
+ outputs=feedback_result
279
+ )
280
 
281
+ if __name__ == "__main__":
282
+ demo.launch()