import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)

# Constants
THRESHOLD = 2  # Minimum reward a translation must reach to be accepted (from Plan2Align)

# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load models once at startup
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
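
# Note: device_map="auto" lets accelerate place the model's layers on the available
# GPU(s)/CPU, and torch.float16 roughly halves the memory footprint versus float32.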

class RewardModel:
    """Wraps the Plan2Align reward model (a value-head LM) used to score candidate translations."""

    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Set a chat template if the tokenizer does not already define one
        if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
            # Minimal Llama 3-style chat template
            self.tokenizer.chat_template = "<|begin_of_text|>{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n' + message['content'] + '<|eot_id|>'}}{% endfor %}"
        print("Loading reward model...")
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            device_map={"": 0},  # Keep the reward model on GPU 0
            torch_dtype=torch_dtype
        )
        self.RM.eval()
        print("Reward model loaded successfully!")

    def _create_single_message(self, language, source, translation):
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]

    def _process_inputs(self, messages):
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise

    def reward_fn(self, language, source, translations):
        """Score each candidate translation of `source`; returns one scalar reward per candidate."""
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    outputs = self.RM(**inputs, return_value=True)
                    rewards = outputs[2]  # value-head output: one score per token position
                    reward = rewards[0, -1].cpu().item()  # use the score at the final token
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise

    def get_len(self, language, translations):
        """Total token count across `translations` (the `language` argument is currently unused)."""
        try:
            len_ = 0
            for translation in translations:
                # No need to move the tensor to the GPU just to read its length
                l = self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
                len_ += l
            return len_
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise

# Create reward model instance with the already-loaded tokenizer
reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")
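
# Illustrative call (made-up inputs, not executed here):
#   scores = reward_model.reward_fn("English", "你好,世界", ["Hello, world", "Hi world"])
# Each element of `scores` is the reward model's scalar score for the corresponding candidate;
# higher means the reward model prefers that translation.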

# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
    translations = [translation0, translation1]
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')
    rewards = reward_model.reward_fn(language, source.replace('</s>', ' '), translations)
    best_index = rewards.index(max(rewards))
    return best_index
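
# Illustrative use (made-up inputs): rm_predict_preference("你好", "Hello", "Hi") returns 0 or 1,
# the index of whichever candidate the reward model scores higher.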

def rm_find_best_translation(source, translations, language="English"):
    copy_translations = translations.copy()
    if len(translations) < 2:
        return translations[0] if translations else None
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')
    rewards = reward_model.reward_fn(language, ''.join(source).replace('</s>', ' '), translations)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(translations)}, and best translation index is: {best_index}")
    if rewards[best_index] >= THRESHOLD:
        return copy_translations[best_index]
    else:
        return None
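
# rm_find_best_translation only returns the top-scoring candidate if its reward clears
# THRESHOLD; otherwise it returns None so the caller can fall back to a default choice.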

def translate_chinese_to_english(chinese_text):
    # Generate three candidate translations with different system prompts
    translations = []
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    for prompt in system_prompts:
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Translate the following Chinese text to English:\n\n{chinese_text}"}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(device)
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        # Decode only the newly generated tokens
        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        translations.append(translation)
    # Use the reward model to pick the best candidate
    best_translation = rm_find_best_translation(chinese_text, translations)
    if best_translation is None:
        # If no candidate meets the threshold, fall back to the first one
        return translations[0]
    return best_translation
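
# Illustrative usage (made-up input, not executed at import time):
#   english = translate_chinese_to_english("今天天气很好。")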

# Gradio interface
def process_text(text):
    return translate_chinese_to_english(text)

demo = gr.Interface(
    fn=process_text,
    inputs=gr.Textbox(lines=5, placeholder="Enter Chinese text here..."),
    outputs=gr.Textbox(lines=5),
    title="Chinese to English Translation with Plan2Align",
    description="This app uses the Plan2Align approach to translate Chinese text to English."
)

if __name__ == "__main__":
    demo.launch()
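
# When running outside Hugging Face Spaces, demo.launch(share=True) can be used instead
# to expose a temporary public URL.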