javi8979 committed on
Commit 1b19c76 (verified)
1 Parent(s): a916a98

Update app.py

Files changed (1)
  1. app.py +73 -51
app.py CHANGED
@@ -12,67 +12,89 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
-    torch_dtype=torch.bfloat16 # Use bf16 as in the original example
+    torch_dtype=torch.bfloat16
 )

-languages = [ 'Aragonese', 'Asturian', 'Basque', 'Bulgarian', 'Catalan', 'Valencian', 'Croatian', 'Czech', 'Danish', 'Dutch', 'English', 'Estonian',
-              'Finnish', 'French', 'Galician', 'German', 'Greek', 'Hungarian', 'Irish', 'Italian', 'Latvian', 'Lithuanian', 'Maltese', 'Norwegian Bokmål',
-              'Norwegian Nynorsk', 'Occitan', 'Aranese', 'Polish', 'Portuguese', 'Romanian', 'Russian', 'Serbian', 'Slovak', 'Slovenian', 'Spanish', 'Swedish',
-              'Ukrainian', 'Welsh' ]
-
-example_sentence = ["Ahir se'n va anar, va agafar les seves coses i es va posar a navegar."]
+languages = sorted([ 'Aragonese', 'Asturian', 'Basque', 'Bulgarian', 'Catalan', 'Valencian', 'Croatian', 'Czech', 'Danish', 'Dutch', 'English', 'Estonian',
+                     'Finnish', 'French', 'Galician', 'German', 'Greek', 'Hungarian', 'Irish', 'Italian', 'Latvian', 'Lithuanian', 'Maltese', 'Norwegian Bokmål',
+                     'Norwegian Nynorsk', 'Occitan', 'Aranese', 'Polish', 'Portuguese', 'Romanian', 'Russian', 'Serbian', 'Slovak', 'Slovenian', 'Spanish', 'Swedish',
+                     'Ukrainian', 'Welsh' ])

 @spaces.GPU(duration=120)
-def translate(input_text, source, target):
-    sentences = [s for s in input_text.strip().split('\n') if s.strip()]
-    translated_sentences = []
-
-    for sentence in sentences:
-        prompt_text = f"Translate the following text from {source} into {target}.\n{source}: {sentence} \n{target}:"
-        messages = [{"role": "user", "content": prompt_text}]
-        date_string = datetime.today().strftime('%Y-%m-%d')
-
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            date_string=date_string
-        )
-
-        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
-        input_length = inputs.input_ids.shape[1]
-
-        output = model.generate(
-            input_ids=inputs.input_ids,
-            max_new_tokens=400,
-            early_stopping=True,
-            num_beams=5
-        )
-
-        decoded = tokenizer.decode(output[0, input_length:], skip_special_tokens=True).strip()
-        translated_sentences.append(decoded)
-
-    return '\n'.join(translated_sentences), ""
+def generate_output(task, source, target, input_text, mt_text=None):
+    date_string = datetime.today().strftime('%Y-%m-%d')
+
+    if task == "Translation":
+        prompt = f"Translate the following text from {source} into {target}.\n{source}: {input_text.strip()} \n{target}:"
+    elif task == "Post-editing":
+        if not mt_text:
+            return "Please provide machine translation (MT) for post-editing.", ""
+        prompt = f"Please fix any mistakes in the following {source}-{target} machine translation or keep it unedited if it's correct.\nSource: {input_text.strip()} \nMT: {mt_text.strip()} \nCorrected:"
+    elif task == "Document translation":
+        prompt = f"Please translate this text from {source} into {target}.\n{source}: {input_text.strip()}\n{target}:"
+    elif task == "Grammar checker":
+        prompt = f"Please fix any mistakes in the following {source} sentence or keep it unedited if it's correct.\nSentence: {input_text.strip()} \nCorrected:"
+    elif task == "Named-entity recognition":
+        prompt = """Analyse the following tokenized text and mark the tokens containing named entities.
+Use the following annotation guidelines with these tags for named entities:
+- ORG (Refers to named groups or organizations)
+- PER (Refers to individual people or named groups of people)
+- LOC (Refers to physical places or natural landmarks)
+- MISC (Refers to entities that don't fit into standard categories).
+Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.
+If a token is not a named entity, label it as O.
+Input: """ + str(input_text.strip()) + "\nMarked:"
+
+    messages = [{"role": "user", "content": prompt}]
+    final_prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        date_string=date_string
+    )
+
+    inputs = tokenizer(final_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
+    input_length = inputs.input_ids.shape[1]
+
+    output = model.generate(
+        input_ids=inputs.input_ids,
+        max_new_tokens=512,
+        early_stopping=True,
+        num_beams=5
+    )
+
+    decoded = tokenizer.decode(output[0, input_length:], skip_special_tokens=True).strip()
+    return decoded, ""

 with gr.Blocks() as demo:
-    gr.HTML("""<html>
-    <head><style>h1 { text-align: center; }</style></head>
-    <body><h1>SalamandraTA 7B Translate</h1></body>
-    </html>""")
+    gr.Markdown("# SalamandraTA 7B - Multitask Demo")
+    gr.Markdown("Explore the translation, grammar correction, NER and post-editing capabilities of the SalamandraTA 7B model.")

     with gr.Row():
-        with gr.Column():
-            source_language_dropdown = gr.Dropdown(choices=languages, value="Catalan", label="Source Language")
-            input_textbox = gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text")
-        with gr.Column():
-            target_language_dropdown = gr.Dropdown(choices=languages, value="English", label="Target Language")
-            translated_textbox = gr.Textbox(lines=5, placeholder="", label="Translated Text")
+        task_selector = gr.Radio(["Translation", "Document translation", "Post-editing", "Grammar checker", "Named-entity recognition"], value="Translation", label="Select Task")
+
+    with gr.Row():
+        source_lang = gr.Dropdown(choices=languages, value="Catalan", label="Source Language")
+        target_lang = gr.Dropdown(choices=languages, value="English", label="Target Language")
+
+    input_textbox = gr.Textbox(lines=6, placeholder="Enter source text or token list here", label="Input Text")
+    mt_textbox = gr.Textbox(lines=4, placeholder="(Only for Post-editing) Enter machine translation", label="Machine Translation (optional)")
+    output_textbox = gr.Textbox(lines=6, label="Output")

     info_label = gr.HTML("")
-    btn = gr.Button("Translate")
-    btn.click(translate, inputs=[input_textbox, source_language_dropdown, target_language_dropdown],
-              outputs=[translated_textbox, info_label])
-    gr.Examples(example_sentence, inputs=[input_textbox])
+    translate_btn = gr.Button("Generate")
+    translate_btn.click(generate_output, inputs=[task_selector, source_lang, target_lang, input_textbox, mt_textbox], outputs=[output_textbox, info_label])
+
+    gr.Examples(
+        examples=[
+            ["Translation", "Catalan", "Galician", "Als antics egipcis del període de l'Imperi Nou els fascinaven els monuments dels seus predecessors, que llavors tenien més de mil anys.", ""],
+            ["Post-editing", "Catalan", "English", "Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.", "Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way."],
+            ["Grammar checker", "Catalan", "", "Entonses, el meu jefe m'ha dit que he de treballar els fins de setmana.", ""],
+            ["Named-entity recognition", "", "", "['La', 'defensa', 'del', 'antiguo', 'responsable', 'de', 'la', 'RFEF', 'confirma', 'que', 'interpondrá', 'un', 'recurso.']", ""],
+            ["Document translation", "English", "Asturian", "President Donald Trump, who campaigned on promises to crack down on illegal immigration, has raised alarms in the U.S. dairy industry...", ""]
+        ],
+        inputs=[task_selector, source_lang, target_lang, input_textbox, mt_textbox]
+    )

 if __name__ == "__main__":
     demo.launch()
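
The generation recipe that the new `generate_output()` follows is the same one the old `translate()` used: build a task-specific prompt, wrap it in the tokenizer's chat template (passing `date_string`), run a 5-beam search, and decode only the newly generated tokens. Below is a minimal standalone sketch of that pattern for the Translation task; the `model_id` value is an assumption for illustration, since the Space defines its own `model_id` above the hunk shown here.

```python
# Minimal sketch of the generation pattern used by generate_output() for the
# "Translation" task. NOTE: the checkpoint name below is an assumption; the
# Space defines its own model_id earlier in app.py.
import torch
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "BSC-LT/salamandraTA-7b-instruct"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

source, target = "Catalan", "English"
text = "Ahir se'n va anar, va agafar les seves coses i es va posar a navegar."
prompt = f"Translate the following text from {source} into {target}.\n{source}: {text} \n{target}:"

# Wrap the prompt in the model's chat template; the template expects a date_string.
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=False,
    add_generation_prompt=True,
    date_string=datetime.today().strftime('%Y-%m-%d'),
)

inputs = tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
output = model.generate(
    input_ids=inputs.input_ids,
    max_new_tokens=512,
    num_beams=5,
    early_stopping=True,
)
# Decode only the tokens generated after the prompt.
print(tokenizer.decode(output[0, inputs.input_ids.shape[1]:], skip_special_tokens=True).strip())
```

The other tasks (post-editing, document translation, grammar checking, NER) reuse this exact flow; only the prompt string changes.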