Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from trl import AutoModelForCausalLMWithValueHead | |
# Weight dtype shared by both models, and the device tokenized inputs are moved to.
torch_dtype = torch.bfloat16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Everything below runs once at import time so each request reuses the same
# tokenizer and models.
print("Loading models...")

# Replace with your actual model ID
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch_dtype, device_map="auto"
)

# Reward model: a causal LM wrapped with a value head, put into eval mode.
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    "ray24724919/plan2align_rm", torch_dtype=torch_dtype, device_map="auto"
)
RM.eval()

print("Models loaded successfully!")
# Self-contained translation and evaluation functions | |
def translate(source_text, target_language="English"):
    """
    Translate text from Chinese to the specified target language.

    The request is rendered with the tokenizer's chat template, sampled from
    ``lm_model`` (temperature 0.7, at most 512 new tokens), and only the newly
    generated tokens are decoded.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        str: The translated text, stripped of surrounding whitespace
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}
    ]

    # Render the conversation to text, ending with the generation prompt so
    # the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The rendered template already carries the special tokens it needs, so
    # the tokenizer must not add them a second time (e.g. a duplicated BOS).
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    prompt_length = inputs.input_ids.shape[1]

    # Generate translation
    with torch.no_grad():
        outputs = lm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    The same system/user prompt used for translation is rebuilt with the
    translation appended as the assistant turn, and the reward model's value
    head is read at the final token.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation

    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation}
    ]

    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # The rendered template already carries the special tokens it needs, so
    # the tokenizer must not add them a second time (e.g. a duplicated BOS).
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    with torch.no_grad():
        # The value-head wrapper returns a plain tuple
        # (lm_logits, loss, value), not an object with a `.value` attribute.
        outputs = RM(input_ids=inputs.input_ids)
    values = outputs[2]  # one value per token: (batch, seq_len)

    # Score the whole exchange with the value at the last token of the single
    # sequence in the batch.
    # NOTE(review): last-token readout is the usual convention for value-head
    # reward models -- confirm it matches how this RM was trained.
    return values[0, -1].float().item()
# Combined function for the Gradio interface | |
def translate_text(source_text, target_language):
    """
    Translate text and get reward score

    Acts as the UI boundary: any exception raised while translating or
    scoring is turned into an error message instead of propagating.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        tuple: (translation, reward_score). For missing/blank input the
            first element is a prompt to enter text; on failure it is
            "Error: <message>". In both cases the score is 0.0.
    """
    # Treat a missing value like blank text so it yields the same prompt
    # rather than an AttributeError from calling .strip() on None.
    if source_text is None or not source_text.strip():
        return "Please enter some text to translate.", 0.0

    try:
        translation = translate(source_text, target_language)
        score = float(evaluate_translation(source_text, translation, target_language))
    except Exception as e:
        # Deliberately broad: show the failure in the output box.
        return f"Error: {str(e)}", 0.0

    return translation, score
# Define available target languages
target_languages = [
    "English", "French", "Spanish", "German", "Italian",
    "Portuguese", "Russian", "Japanese", "Korean", "Arabic",
]


# Function to update the quality rating based on score
def update_quality_rating(score):
    """
    Map a reward score to a coarse quality label.

    Args:
        score (float | None): The value currently held by the score field

    Returns:
        str | None: "Excellent", "Good", "Average", "Poor" or "Very Poor";
            None when there is no score, so no rating is shown.
    """
    # The number field can be empty (None); comparing None with a float
    # would raise TypeError.
    if score is None:
        return None
    # NOTE(review): these cut-offs assume scores fall roughly in [0, 1];
    # nothing here bounds the reward model's output -- confirm the scale.
    if score >= 0.8:
        return "Excellent"
    elif score >= 0.6:
        return "Good"
    elif score >= 0.4:
        return "Average"
    elif score >= 0.2:
        return "Poor"
    else:
        return "Very Poor"


# Create the Gradio interface
with gr.Blocks(title="Chinese Translation with Reward Scoring") as demo:
    gr.Markdown("# Chinese to Any Language Translation")
    gr.Markdown("This demo translates Chinese text to your chosen language and provides a quality score from our reward model.")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            source_text = gr.Textbox(
                label="Chinese Text",
                placeholder="Enter Chinese text here...",
                lines=5,
            )
            target_language = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language",
            )
            translate_button = gr.Button("Translate")

        # Right column: read-only results.
        with gr.Column():
            translation_output = gr.Textbox(
                label="Translation",
                lines=5,
                interactive=False,
            )
            reward_score = gr.Number(
                label="Translation Quality Score (higher is better)",
                precision=4,
                interactive=False,
            )

    with gr.Row():
        score_indicator = gr.Label(label="Quality Rating")

    # Set up the translation flow
    translate_outputs = translate_button.click(
        fn=translate_text,
        inputs=[source_text, target_language],
        outputs=[translation_output, reward_score],
    )

    # Update the quality rating whenever the reward score changes
    reward_score.change(
        fn=update_quality_rating,
        inputs=[reward_score],
        outputs=[score_indicator],
    )

    # Examples
    gr.Examples(
        examples=[
            ["你好,世界!", "English"],
            ["我喜欢学习新的语言。", "Spanish"],
            ["北京烤鴨很好吃。", "French"],
            ["人工智能正在改变世界。", "German"],
            ["今天天气真好。", "Japanese"],
        ],
        inputs=[source_text, target_language],
        outputs=[translation_output, reward_score],
        fn=translate_text,
    )

    gr.Markdown("## How It Works")
    gr.Markdown(
        "\n"
        "1. Enter Chinese text in the input box\n"
        "2. Select your desired target language\n"
        "3. Click 'Translate' to get the translation\n"
        "4. The system will display the translation and a quality score\n"
        "The quality score is generated by a reward model trained to evaluate translation quality.\n"
        "Higher scores indicate better translations.\n"
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()