Sshubam committed on
Commit 1ca4f47 · verified · 1 Parent(s): c42a43b

update app

Files changed (1)
  1. app.py +191 -92
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import torch
 import spaces
+import psycopg2
 import gradio as gr
 from threading import Thread
 from collections.abc import Iterator
@@ -9,39 +10,126 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
 MAX_MAX_NEW_TOKENS = 4096
 MAX_INPUT_TOKEN_LENGTH = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
-HF_TOKEN = os.environ['HF_TOKEN']
+HF_TOKEN = os.environ["HF_TOKEN"]
 
 model_id = "ai4bharat/IndicTrans3-beta"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
+)
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
 
-LANGUAGES = {
-    "Hindi": "hin_Deva",
-    "Bengali": "ben_Beng",
-    "Telugu": "tel_Telu",
-    "Marathi": "mar_Deva",
-    "Tamil": "tam_Taml",
-    "Urdu": "urd_Arab",
-    "Gujarati": "guj_Gujr",
-    "Kannada": "kan_Knda",
-    "Odia": "ori_Orya",
-    "Malayalam": "mal_Mlym",
-    "Punjabi": "pan_Guru",
-    "Assamese": "asm_Beng",
-    "Maithili": "mai_Mith",
-    "Santali": "sat_Olck",
-    "Kashmiri": "kas_Arab",
-    "Nepali": "nep_Deva",
-    "Sindhi": "snd_Arab",
-    "Konkani": "kok_Deva",
-    "Dogri": "dgo_Deva",
-    "Manipuri": "mni_Beng",
-    "Bodo": "brx_Deva"
-}
+
+LANGUAGES = [
+    "Hindi",
+    "Bengali",
+    "Telugu",
+    "Marathi",
+    "Tamil",
+    "Urdu",
+    "Gujarati",
+    "Kannada",
+    "Odia",
+    "Malayalam",
+    "Punjabi",
+    "Assamese",
+    "Maithili",
+    "Santali",
+    "Kashmiri",
+    "Nepali",
+    "Sindhi",
+    "Konkani",
+    "Dogri",
+    "Manipuri",
+    "Bodo",
+]
+
 
 def format_message_for_translation(message, target_lang):
     return f"Translate the following text to {target_lang}: {message}"
 
+
+def store_feedback(rating, feedback_text, chat_history, tgt_lang):
+    try:
+
+        if not rating:
+            gr.Warning("Please select a rating before submitting feedback.", duration=5)
+            return None
+
+        if not feedback_text or feedback_text.strip() == "":
+            gr.Warning("Please provide some feedback before submitting.", duration=5)
+            return None
+
+        if not chat_history:
+            gr.Warning(
+                "Please provide the input text before submitting feedback.", duration=5
+            )
+            return None
+
+        if len(chat_history[0]) < 2:
+            gr.Warning(
+                "Please translate the input text before submitting feedback.",
+                duration=5,
+            )
+            return None
+
+        conn = psycopg2.connect(
+            host=os.getenv("DB_HOST"),
+            database=os.getenv("DB_NAME"),
+            user=os.getenv("DB_USER"),
+            password=os.getenv("DB_PASSWORD"),
+            port=os.getenv("DB_PORT"),
+        )
+
+        cursor = conn.cursor()
+
+        insert_query = """
+            INSERT INTO feedback
+            (tgt_lang, rating, feedback_txt, chat_history)
+            VALUES (%s, %s, %s, %s)
+        """
+
+        cursor.execute(
+            insert_query, (tgt_lang, int(rating), feedback_text, chat_history)
+        )
+
+        conn.commit()
+
+        cursor.close()
+        conn.close()
+
+        gr.Info("Thank you for your feedback! 🙏", duration=5)
+
+    except:
+        gr.Error(
+            "An error occurred while storing feedback. Please try again later.",
+            duration=5,
+        )
+
+
+def store_output(tgt_lang, input_text, output_text):
+
+    conn = psycopg2.connect(
+        host=os.getenv("DB_HOST"),
+        database=os.getenv("DB_NAME"),
+        user=os.getenv("DB_USER"),
+        password=os.getenv("DB_PASSWORD"),
+        port=os.getenv("DB_PORT"),
+    )
+
+    cursor = conn.cursor()
+
+    insert_query = """
+        INSERT INTO translation
+        (input_txt, output_txt, tgt_lang)
+        VALUES (%s, %s, %s)
+    """
+
+    cursor.execute(insert_query, (input_text, output_text, tgt_lang))
+
+    conn.commit()
+    cursor.close()
+
+
 @spaces.GPU
 def translate_message(
     message: str,
@@ -54,18 +142,24 @@ def translate_message(
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
-
+
     translation_request = format_message_for_translation(message, target_language)
-    print(f"Translation request: {translation_request}")
+
     conversation.append({"role": "user", "content": translation_request})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
+    input_ids = tokenizer.apply_chat_template(
+        conversation, return_tensors="pt", add_generation_prompt=True
+    )
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        gr.Warning(
+            f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+        )
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
+    )
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -85,17 +179,8 @@ def translate_message(
         outputs.append(text)
         yield "".join(outputs)
 
-def store_feedback(rating, feedback_text):
-    if not rating:
-        gr.Warning("Please select a rating before submitting feedback.", duration=5)
-        return None
-
-    if not feedback_text or feedback_text.strip() == "":
-        gr.Warning("Please provide some feedback before submitting.", duration=5)
-        return None
-
-    gr.Info("Feedback submitted successfully!")
-    return "Thank you for your feedback!"
+    store_output(target_language, message, "".join(outputs))
+
 
 css = """
 # body {
@@ -140,58 +225,62 @@ Built for <b>linguistic diversity and accessibility</b>, IndicTrans3 is a major
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_classes="container"):
-        gr.Markdown("# 🌏 IndicTrans3-beta 🚀: Multilingual Translation for 22 Indic Languages </center>")
+        gr.Markdown(
+            "# 🌏 IndicTrans3-beta 🚀: Multilingual Translation for 22 Indic Languages </center>"
+        )
         gr.Markdown(DESCRIPTION)
-
+
         target_language = gr.Dropdown(
-            list(LANGUAGES.keys()),
+            LANGUAGES,
             value="Hindi",
             label="Which language would you like to translate to?",
-            elem_id="language-dropdown"
+            elem_id="language-dropdown",
         )
-
+
         chatbot = gr.Chatbot(height=400, elem_id="chatbot")
-
+
         with gr.Row():
             msg = gr.Textbox(
                 placeholder="Enter text to translate...",
                 show_label=False,
                 container=False,
-                scale=9
+                scale=9,
             )
             submit_btn = gr.Button("Translate", scale=1)
-
+
         gr.Examples(
             examples=[
                 "The Taj Mahal stands majestically along the banks of river Yamuna, a timeless symbol of eternal love.",
                 "Kumbh Mela is the world's largest gathering of people, where millions of pilgrims bathe in sacred rivers for spiritual purification.",
                 "India's classical dance forms like Bharatanatyam, Kathak, and Odissi beautifully blend rhythm, expression, and storytelling.",
                 "Ayurveda, the ancient Indian medical system, focuses on holistic wellness through natural herbs and balanced living.",
-                "During Diwali, homes across India are decorated with oil lamps, colorful rangoli patterns, and twinkling lights to celebrate the victory of light over darkness."
+                "During Diwali, homes across India are decorated with oil lamps, colorful rangoli patterns, and twinkling lights to celebrate the victory of light over darkness.",
             ],
-            inputs=msg
+            inputs=msg,
         )
-
-
+
        with gr.Accordion("Provide Feedback", open=True):
            gr.Markdown("## Rate Translation & Provide Feedback 📝")
-            gr.Markdown("Help us improve the translation quality by providing your feedback.")
+            gr.Markdown(
+                "Help us improve the translation quality by providing your feedback."
+            )
            with gr.Row():
                rating = gr.Radio(
-                    ["1", "2", "3", "4", "5"],
-                    label="Translation Rating (1-5)"
+                    ["1", "2", "3", "4", "5"], label="Translation Rating (1-5)"
                )
-
+
            feedback_text = gr.Textbox(
                placeholder="Share your feedback about the translation...",
                label="Feedback",
-                lines=3
+                lines=3,
            )
-
+
            feedback_submit = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="", visible=False)
 
-        with gr.Accordion("Advanced Options", open=False, elem_classes="advanced-options"):
+        with gr.Accordion(
+            "Advanced Options", open=False, elem_classes="advanced-options"
+        ):
            max_new_tokens = gr.Slider(
                label="Max new tokens",
                minimum=1,
@@ -227,56 +316,66 @@ with gr.Blocks(css=css) as demo:
                step=0.05,
                value=1.0,
            )
-
+
        chat_state = gr.State([])
-
+
        def user(user_message, history, target_lang):
            return "", history + [[user_message, None]]
-
-        def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty):
+
+        def bot(
+            history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty
+        ):
            user_message = history[-1][0]
            history[-1][1] = ""
-
+
            for chunk in translate_message(
-                user_message,
-                history[:-1],
-                target_lang,
-                max_tokens,
-                temp,
-                top_p_val,
-                top_k_val,
-                rep_penalty
+                user_message,
+                history[:-1],
+                target_lang,
+                max_tokens,
+                temp,
+                top_p_val,
+                top_k_val,
+                rep_penalty,
            ):
                history[-1][1] = chunk
                yield history
-
+
        msg.submit(
-            user,
-            [msg, chatbot, target_language],
-            [msg, chatbot],
-            queue=False
+            user, [msg, chatbot, target_language], [msg, chatbot], queue=False
        ).then(
            bot,
-            [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-            chatbot
+            [
+                chatbot,
+                target_language,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+                repetition_penalty,
+            ],
+            chatbot,
        )
-
+
        submit_btn.click(
-            user,
-            [msg, chatbot, target_language],
-            [msg, chatbot],
-            queue=False
+            user, [msg, chatbot, target_language], [msg, chatbot], queue=False
        ).then(
            bot,
-            [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-            chatbot
+            [
+                chatbot,
+                target_language,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+                repetition_penalty,
+            ],
+            chatbot,
        )
-
+
        feedback_submit.click(
-            fn=store_feedback,
-            inputs=[rating, feedback_text],
-            outputs=feedback_result
+            fn=store_feedback,
+            inputs=[rating, feedback_text, chatbot, target_language],
        )
-
 if __name__ == "__main__":
     demo.launch()
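
The new store_feedback and store_output helpers assume a PostgreSQL database reachable through the DB_HOST, DB_NAME, DB_USER, DB_PASSWORD, and DB_PORT environment variables, with feedback and translation tables already provisioned; the commit itself does not create them. Below is a minimal setup sketch, kept in Python/psycopg2 to match the app code. The CREATE TABLE statements and column types are assumptions inferred from the INSERT statements in the diff, not part of this commit.

# Hypothetical one-off schema setup -- not part of this commit.
# Column types are guesses based on how store_feedback/store_output bind values.
import os
import psycopg2

DDL = """
CREATE TABLE IF NOT EXISTS translation (
    id         SERIAL PRIMARY KEY,
    input_txt  TEXT,
    output_txt TEXT,
    tgt_lang   TEXT
);
CREATE TABLE IF NOT EXISTS feedback (
    id           SERIAL PRIMARY KEY,
    tgt_lang     TEXT,
    rating       INTEGER,
    feedback_txt TEXT,
    chat_history TEXT[]  -- psycopg2 adapts the nested chat list to a Postgres array
);
"""

conn = psycopg2.connect(
    host=os.getenv("DB_HOST"),
    database=os.getenv("DB_NAME"),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    port=os.getenv("DB_PORT"),
)
with conn, conn.cursor() as cur:  # the connection context manager commits on success
    cur.execute(DDL)
conn.close()

With tables along these lines in place, the parameter tuples the new code passes to cursor.execute line up column-for-column with the feedback and translation inserts.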