Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from trl import AutoModelForCausalLMWithValueHead | |
from safetensors.torch import load_file | |
import logging | |
from huggingface_hub import login | |
# Set up logging (the comment previously had no configuration behind it).
logging.basicConfig(level=logging.INFO)

# Authenticate with the Hugging Face Hub using the token stored in the
# ``LA_NAME`` environment variable. huggingface_hub's ``login(token=None)``
# falls back to an interactive prompt, which cannot be answered in a headless
# Space, so only log in when a token is actually present.
_hf_token = os.environ.get("LA_NAME")
if _hf_token:
    login(token=_hf_token)
else:
    logging.warning("LA_NAME is not set; skipping Hugging Face Hub login.")

# Constants
THRESHOLD = 2  # From Plan2Align: minimum reward the best candidate must reach

# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models once
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# float16 halves GPU memory for the 8B model, but half-precision kernels may be
# unsupported or very slow on CPU, so fall back to float32 without a GPU.
_model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=_model_dtype,
)
class RewardModel:
    """Score candidate translations with the Plan2Align reward model.

    The reward model (``ray24724919/plan2align_rm``) is loaded through TRL's
    ``AutoModelForCausalLMWithValueHead``. A translation's reward is the value
    head output at the last token of the system/user/assistant conversation
    built by ``_create_single_message``.
    """

    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        """Load the reward model onto ``device``.

        Args:
            device: Device (``torch.device`` or string) that the reward model
                and its inputs are placed on.
            tokenizer: Tokenizer used to format and encode conversations. It is
                modified in place: a pad token is set if missing and, if it has
                no chat template, a fallback template is installed.
            torch_dtype: dtype the reward model weights are loaded in.
        """
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            # No pad token available: reuse EOS so padding=True works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Set chat template if not already set
        if getattr(self.tokenizer, "chat_template", None) is None:
            # Llama-3-style fallback template
            self.tokenizer.chat_template = "<|begin_of_text|>{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n' + message['content'] + '<|eot_id|>'}}{% endfor %}"
        print("Loading reward model...")
        # Keep the whole model on the requested device instead of hard-coding
        # GPU 0, so the class honours ``device`` and works on CPU-only hosts.
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            device_map={"": self._device_map_target(device)},
            torch_dtype=torch_dtype,
        )
        self.RM.eval()
        print("Reward model loaded successfully!")

    @staticmethod
    def _device_map_target(device):
        """Translate ``device`` into a value usable in a ``device_map`` entry.

        CUDA devices map to their integer index (current device if no index is
        given); any other device maps to its type string, e.g. ``"cpu"``.
        """
        device = torch.device(device)
        if device.type == "cuda":
            return device.index if device.index is not None else torch.cuda.current_device()
        return device.type

    def _create_single_message(self, language, source, translation):
        """Build the chat conversation the reward model scores.

        Args:
            language: Target language name inserted into the prompt.
            source: Chinese source text.
            translation: Candidate translation placed in the assistant turn.

        Returns:
            A list of three ``{"role", "content"}`` dicts (system, user, assistant).
        """
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]

    def _process_inputs(self, messages):
        """Tokenize one conversation into model-ready tensors on ``self.device``.

        Returns:
            Dict with ``input_ids`` and an all-ones ``attention_mask``, both
            2-D (a batch dimension is added when the template returns 1-D ids).

        Raises:
            Any exception from tokenization, after logging it with a traceback.
        """
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = input_ids.to(self.device)
            # A single unpadded sequence: every position is attended to.
            attention_mask = torch.ones_like(input_ids)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            }
        except Exception as e:
            logging.exception("Error processing inputs: %s", e)
            raise

    def reward_fn(self, language, source, translations):
        """Return one reward (Python float) per candidate in ``translations``.

        Each candidate is scored independently; order is preserved.

        Raises:
            Any exception from tokenization or the model forward pass, after
            logging it with a traceback.
        """
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    # NOTE(review): ``return_value`` is forwarded to the wrapped
                    # model through TRL's **kwargs; confirm the installed
                    # transformers version accepts it.
                    outputs = self.RM(**inputs, return_value=True)
                # Third output is the value head's per-token scores, presumably
                # (batch, seq_len); the last token's score is the reward.
                rewards = outputs[2]
                all_rewards.append(rewards[0, -1].cpu().item())
            return all_rewards
        except Exception as e:
            logging.exception("Error in reward_fn: %s", e)
            raise

    def get_len(self, language, translations):
        """Return the total number of tokens across ``translations``.

        ``language`` is accepted for interface compatibility and is unused.
        Counts include any special tokens the tokenizer adds by default.
        """
        try:
            # Only the length is needed, so there is no reason to build tensors
            # or move them to a device (previously the module-global ``device``).
            return sum(len(self.tokenizer(translation).input_ids) for translation in translations)
        except Exception as e:
            logging.exception("Error in get_len: %s", e)
            raise
# Create reward model instance with the already loaded tokenizer.
# float16 is only used on GPU; on CPU, where half-precision kernels may be
# unsupported or slow, the reward model is loaded in float32 instead.
_reward_dtype = torch.float16 if device.type == "cuda" else torch.float32
reward_model = RewardModel(device, tokenizer, torch_dtype=_reward_dtype)
print("Models loaded successfully!")
# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
    """Return which of two candidate translations the reward model prefers.

    Both candidates (strings or iterables of string chunks) are joined, and any
    ``</s>`` markers in them and in ``source`` are replaced by spaces before
    scoring.

    Returns:
        ``0`` or ``1``: the index of the higher-reward candidate. Ties
        resolve to ``0``.
    """
    cleaned = [''.join(candidate).replace('</s>', ' ') for candidate in (translation0, translation1)]
    scores = reward_model.reward_fn(language, source.replace('</s>', ' '), cleaned)
    return scores.index(max(scores))
def rm_find_best_translation(source, translations, language="English", *, threshold=None, scorer=None):
    """Return the highest-reward translation if it clears the reward threshold.

    Args:
        source: Source text, either a string or an iterable of string chunks
            (joined before scoring).
        translations: Candidate translations; each is a string or an iterable
            of string chunks. The sequence is NOT modified.
        language: Target language name passed to the reward model prompt.
        threshold: Minimum reward the best candidate must reach; defaults to
            the module-level ``THRESHOLD``.
        scorer: Object exposing ``reward_fn(language, source, translations)``;
            defaults to the module-level ``reward_model``.

    Returns:
        With fewer than two candidates, the single candidate (or ``None`` for an
        empty sequence) without scoring. Otherwise the original, uncleaned best
        candidate when its reward is >= ``threshold``, else ``None``. Ties go
        to the earliest candidate.
    """
    if len(translations) < 2:
        return translations[0] if translations else None
    if threshold is None:
        threshold = THRESHOLD
    if scorer is None:
        scorer = reward_model
    # Clean a copy so the caller's list is left intact (it used to be
    # overwritten in place with the '</s>'-stripped strings).
    cleaned = [''.join(candidate).replace('</s>', ' ') for candidate in translations]
    rewards = scorer.reward_fn(language, ''.join(source).replace('</s>', ' '), cleaned)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(translations)}, and best translation index is: {best_index}")
    if rewards[best_index] >= threshold:
        return translations[best_index]
    return None
def _generate_translation(system_prompt, chinese_text):
    """Sample one English translation of ``chinese_text`` under ``system_prompt``."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Translate the following Chinese text to English:\n\n{chinese_text}"}
    ]
    # add_generation_prompt=True appends the assistant header, so the model starts
    # answering immediately instead of emitting the header itself (which leaks a
    # stray "assistant" word into the decoded text once special tokens are skipped).
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)  # device_map="auto": feed inputs where the model's first layer lives
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    outputs = model.generate(
        input_ids,
        # Single unpadded prompt: attend to every token. Passing the mask
        # explicitly avoids ambiguity when the pad token equals EOS.
        attention_mask=torch.ones_like(input_ids),
        pad_token_id=pad_token_id,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()


def translate_chinese_to_english(chinese_text):
    """Translate Chinese text to English using Plan2Align-style reranking.

    Three candidates are sampled with different system prompts (literal,
    formal, creative). The reward model picks the best one. If no candidate
    reaches the reward threshold, the literal candidate is returned.

    Returns:
        The selected English translation, or ``""`` for empty or
        whitespace-only input (no generation is run in that case).
    """
    if not chinese_text or not chinese_text.strip():
        return ""
    # Generate three different translations with different system prompts
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    translations = [_generate_translation(prompt, chinese_text) for prompt in system_prompts]
    # Use reward model to find the best translation
    best_translation = rm_find_best_translation(chinese_text, translations)
    if best_translation is None:
        # If no translation meets the threshold, return the first one
        return translations[0]
    return best_translation
# Gradio interface
def process_text(text):
    """Gradio callback: translate the textbox contents from Chinese to English."""
    english = translate_chinese_to_english(text)
    return english
# Build the single-textbox-in / single-textbox-out Gradio UI around process_text.
_chinese_input = gr.Textbox(lines=5, placeholder="Enter Chinese text here...")
_english_output = gr.Textbox(lines=5)
demo = gr.Interface(
    fn=process_text,
    inputs=_chinese_input,
    outputs=_english_output,
    title="Chinese to English Translation with Plan2Align",
    description="This app uses the Plan2Align approach to translate Chinese text to English.",
)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()