File size: 4,594 Bytes
87d8688
3e11881
b8b5a68
3e11881
 
87d8688
b8b5a68
3e11881
 
b8b5a68
3e11881
b8b5a68
 
 
1b19c76
b8b5a68
87d8688
1b19c76
 
2873048
1b19c76
3e11881
 
1b19c76
 
 
 
b96433e
2cbd116
b96433e
 
 
 
2cbd116
 
 
 
 
b96433e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfdc52b
b96433e
 
 
 
 
 
4de4cdd
3e11881
1b19c76
3e11881
 
b96433e
1b19c76
 
 
 
 
 
 
 
3e11881
 
1b19c76
 
 
 
 
 
 
b336e2d
1b19c76
 
 
3e11881
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datetime import datetime

model_id = "BSC-LT/salamandraTA-7b-instruct"

# The tokenizer and weights are loaded once, at import time; generate_output
# reuses these module-level objects for every request.
# device_map="auto" lets transformers decide device placement, and bfloat16
# stores each weight in 2 bytes instead of float32's 4.
_model_load_options = {
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
}
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_load_options)

# Language names offered in both the source and target dropdowns. The chosen
# name is interpolated verbatim into the model prompt, so spelling matters.
_SUPPORTED_LANGUAGES = (
    'Aragonese', 'Asturian', 'Basque', 'Bulgarian', 'Catalan', 'Valencian',
    'Croatian', 'Czech', 'Danish', 'Dutch', 'English', 'Estonian',
    'Finnish', 'French', 'Galician', 'German', 'Greek', 'Hungarian',
    'Irish', 'Italian', 'Latvian', 'Lithuanian', 'Maltese', 'Norwegian Bokmål',
    'Norwegian Nynorsk', 'Occitan', 'Aranese', 'Polish', 'Portuguese', 'Romanian',
    'Russian', 'Serbian_Cyrillic', 'Slovak', 'Slovenian', 'Spanish', 'Swedish',
    'Ukrainian', 'Welsh',
)

# Alphabetical (code-point) order for display.
languages = sorted(_SUPPORTED_LANGUAGES)

# Prompt template for each task offered in the UI. Templates are filled with
# str.format(); user text is passed as an argument, never as the template, so
# braces inside the user's text are not interpreted as format fields.
_PROMPT_TEMPLATES = {
    "Translation": (
        "Translate the following text from {source} into {target}.\n"
        "{source}: {sentence} \n{target}:"
    ),
    "Post-editing": (
        "Please fix any mistakes in the following {source}-{target} machine "
        "translation or keep it unedited if it's correct.\n"
        "Source: {sentence} \nMT: {mt} \nCorrected:"
    ),
    "Grammar checker": (
        "Please fix any mistakes in the following {source} sentence or keep "
        "it unedited if it's correct.\nSentence: {sentence} \nCorrected:"
    ),
}


def _align_mt_segments(sentences, mt_text):
    """Return a list parallel to *sentences* holding each line's MT segment.

    Blank source lines map to None. With at most one non-empty source line,
    the whole stripped MT text is used for it (the single-sentence case).
    Otherwise the non-empty MT lines are paired, in order, with the non-empty
    source lines. Returns None when those two counts differ.
    """
    content_idx = [i for i, line in enumerate(sentences) if line.strip()]
    segments = [None] * len(sentences)

    if len(content_idx) <= 1:
        for i in content_idx:
            segments[i] = mt_text.strip()
        return segments

    mt_lines = [line.strip() for line in mt_text.split('\n') if line.strip()]
    if len(mt_lines) != len(content_idx):
        return None
    for i, mt_line in zip(content_idx, mt_lines):
        segments[i] = mt_line
    return segments


def _generate_one(prompt, date_string, max_new_tokens=4000):
    """Wrap *prompt* in the chat template, run the model, return only the new text."""
    messages = [{"role": "user", "content": prompt}]
    final_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        date_string=date_string,
    )

    # The chat template already contains any special tokens it needs.
    inputs = tokenizer(final_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    input_length = inputs.input_ids.shape[1]

    # Pass the attention mask explicitly rather than letting generate() infer it.
    # `early_stopping` is omitted: it only affects beam search, and num_beams=1.
    output = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        num_beams=1,
    )

    # Decode only the tokens produced after the prompt.
    return tokenizer.decode(output[0, input_length:], skip_special_tokens=True).strip()


@spaces.GPU(duration=120)
def generate_output(task, source, target, input_text, mt_text=None):
    """Run *task* on every line of *input_text*; return ``(output, info_html)``.

    Each non-empty line is sent to the model as an independent request. Blank
    lines are reproduced as blank output lines, so the result lines up with
    the input.

    Args:
        task: "Translation", "Post-editing" or "Grammar checker".
        source: Source language name inserted into the prompt.
        target: Target language name inserted into the prompt.
        input_text: Newline-separated source text.
        mt_text: Machine translation, used only for "Post-editing". With a
            single non-empty input line, the whole MT text is used. With
            several, it must contain one non-empty line per non-empty input
            line, in the same order.

    Returns:
        A 2-tuple. The first item is the newline-joined model outputs, or a
        user-facing error message. The second is always "" (for the info
        HTML component).
    """
    date_string = datetime.today().strftime('%Y-%m-%d')
    sentences = (input_text or "").split('\n')

    template = _PROMPT_TEMPLATES.get(task)
    if template is None:
        # Without this guard an unknown task would leave the prompt unbound.
        return f"Unsupported task: {task}", ""

    mt_segments = [None] * len(sentences)
    if task == "Post-editing" and any(line.strip() for line in sentences):
        if not mt_text or not mt_text.strip():
            return "Please provide machine translation (MT) for post-editing.", ""
        aligned = _align_mt_segments(sentences, mt_text)
        if aligned is None:
            return (
                "The machine translation must contain one non-empty line per "
                "non-empty line of input text.",
                "",
            )
        mt_segments = aligned

    generated_text = []
    for sentence, mt_segment in zip(sentences, mt_segments):
        sentence = sentence.strip()
        if not sentence:
            # Preserve empty lines so the output stays aligned with the input.
            generated_text.append('')
            continue
        prompt = template.format(source=source, target=target, sentence=sentence, mt=mt_segment)
        generated_text.append(_generate_one(prompt, date_string))

    return '\n'.join(generated_text), ""

with gr.Blocks() as demo:
    gr.Markdown("# 🦎 SalamandraTA 7B - Multitask Demo")

    # Task picker; its value is passed to generate_output as `task`.
    with gr.Row():
        task_selector = gr.Radio(
            ["Translation", "Post-editing", "Grammar checker"],
            value="Translation",
            label="Select Task",
        )

    # Both dropdowns draw from the same language list.
    with gr.Row():
        source_lang = gr.Dropdown(choices=languages, value="Catalan", label="Source Language")
        target_lang = gr.Dropdown(choices=languages, value="English", label="Target Language")

    input_textbox = gr.Textbox(
        lines=6,
        placeholder="Enter source text or token list here",
        label="Input Text",
    )
    mt_textbox = gr.Textbox(
        lines=4,
        placeholder="(Only for Post-editing) Enter machine translation",
        label="Machine Translation (optional)",
    )
    output_textbox = gr.Textbox(lines=6, label="Output")

    info_label = gr.HTML("")
    translate_btn = gr.Button("Generate")

    # The button and the examples use the same component order, which matches
    # generate_output(task, source, target, input_text, mt_text).
    request_inputs = [task_selector, source_lang, target_lang, input_textbox, mt_textbox]
    translate_btn.click(
        fn=generate_output,
        inputs=request_inputs,
        outputs=[output_textbox, info_label],
    )

    # One row per task; each row is ordered like `request_inputs`.
    example_rows = [
        ["Translation", "Catalan", "Galician", "Als antics egipcis del període de l'Imperi Nou els fascinaven els monuments dels seus predecessors, que llavors tenien més de mil anys.", ""],
        ["Post-editing", "Catalan", "English", "Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.", "Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way."],
        ["Grammar checker", "Catalan", "Catalan", "Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.", ""],
    ]
    gr.Examples(examples=example_rows, inputs=request_inputs)

def main():
    """Start the Gradio server for the demo defined above."""
    demo.launch()


if __name__ == "__main__":
    main()