Spaces:
Runtime error
Runtime error
File size: 7,649 Bytes
135276f 57d7889 a024afa cb62b20 135276f cb62b20 135276f cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 57d7889 cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa cb62b20 a024afa 57d7889 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from safetensors.torch import load_file
import logging
# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)

# Minimum reward a candidate must reach in rm_find_best_translation
# (value taken from Plan2Align).
THRESHOLD = 2

# Use the first GPU when CUDA is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# The generation model and its tokenizer are loaded a single time, at import.
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
class RewardModel:
    """Score candidate translations with the Plan2Align reward model.

    Wraps the TRL value-head checkpoint ``ray24724919/plan2align_rm``. Each
    candidate is formatted as a (system, user, assistant) conversation, and the
    value-head output at the conversation's last token is used as its reward.
    """

    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        """Load the reward model.

        Args:
            device: device that model inputs are moved to; the reward model is
                placed on the same device.
            tokenizer: tokenizer used to encode conversations. It is modified
                in place: ``pad_token`` falls back to ``eos_token`` and a
                fallback chat template is installed when none is set.
            torch_dtype: dtype for the reward-model weights.
        """
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            # _process_inputs requests padding=True, which needs a pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Set chat template if not already set
        if getattr(self.tokenizer, "chat_template", None) is None:
            # Fallback Llama-3-style template. NOTE(review): it emits a single
            # "\n" after each header; confirm this matches the format the
            # reward model was trained on.
            self.tokenizer.chat_template = "<|begin_of_text|>{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n' + message['content'] + '<|eot_id|>'}}{% endfor %}"
        print("Loading reward model...")
        target = torch.device(self.device)
        load_kwargs = {"torch_dtype": torch_dtype}
        if target.type == "cuda":
            # Pin the whole model to the configured GPU (index 0 when the
            # device carries no explicit index, as the old hard-coded map did).
            load_kwargs["device_map"] = {"": 0 if target.index is None else target.index}
        # Otherwise no device_map is passed. The previous hard-coded GPU 0
        # failed on CPU-only hosts. NOTE(review): TRL's value-head wrapper is
        # believed to reject device maps that place modules on "cpu"/"disk",
        # hence the explicit .to() below instead; verify against installed trl.
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            **load_kwargs,
        )
        if target.type != "cuda":
            self.RM.to(target)
        self.RM.eval()
        print("Reward model loaded successfully!")

    def _create_single_message(self, language, source, translation):
        """Build the chat conversation scored for one candidate translation.

        Args:
            language: target-language name inserted into the user prompt.
            source: Chinese source text.
            translation: candidate translation, used as the assistant turn.

        Returns:
            A list of three ``{"role", "content"}`` dicts: system, user, assistant.
        """
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]

    def _process_inputs(self, messages):
        """Tokenize a conversation into model inputs on ``self.device``.

        Returns:
            Dict with ``input_ids`` and an all-ones ``attention_mask``, both
            given a leading batch dimension if the tokenizer returned 1-D ids.
        """
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            # A single conversation has no padding, so every position is attended.
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise

    def reward_fn(self, language, source, translations):
        """Score each translation of ``source``.

        Args:
            language: target-language name used in the prompt.
            source: Chinese source text.
            translations: iterable of candidate translations.

        Returns:
            List of Python floats, one per translation, in input order.
        """
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    # `return_value=True` was removed: it is not a parameter of
                    # AutoModelForCausalLMWithValueHead.forward and was being
                    # forwarded into the base model's forward() via **kwargs.
                    outputs = self.RM(**inputs)
                # outputs[2] holds the value-head outputs; assumed shape is
                # (batch, seq) — the [0, -1] index reads the last position.
                rewards = outputs[2]
                reward = rewards[0, -1].cpu().item()
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise

    def get_len(self, language, translations):
        """Return the summed token count of ``translations``.

        ``language`` is accepted for interface compatibility and is unused.
        """
        try:
            # Only the sequence length is needed, so the ids are not moved to
            # any device (previously they were copied to the global `device`).
            return sum(
                self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
                for translation in translations
            )
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise
# Build the shared reward-model scorer, reusing the generation tokenizer.
reward_model = RewardModel(
    device=device,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
)
print("Models loaded successfully!")
# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
    """Return 0 or 1: the index of the candidate the reward model scores higher.

    ``'</s>'`` markers are replaced with spaces in the source and in both
    candidates before scoring. On a tie, index 0 is returned.
    """
    candidates = [
        ''.join(candidate).replace('</s>', ' ')
        for candidate in (translation0, translation1)
    ]
    scores = reward_model.reward_fn(language, source.replace('</s>', ' '), candidates)
    top_score = max(scores)
    return scores.index(top_score)
def rm_find_best_translation(source, translations, language="English"):
    """Pick the highest-reward translation, or ``None`` if it is not good enough.

    Args:
        source: Chinese source text (a string, or an iterable of strings that
            is concatenated).
        translations: sequence of candidate translations. It is not modified.
        language: target-language name forwarded to the reward model.

    Returns:
        With fewer than two candidates, no scoring is done: the single
        candidate is returned, or ``None`` for an empty sequence. Otherwise,
        the original (uncleaned) candidate with the highest reward is
        returned if that reward is at least ``THRESHOLD``, else ``None``.
    """
    if len(translations) < 2:
        return translations[0] if translations else None
    # Clean into a NEW list. The previous version overwrote the caller's list
    # in place, so callers saw '</s>'-stripped strings afterwards.
    cleaned = [''.join(t).replace('</s>', ' ') for t in translations]
    rewards = reward_model.reward_fn(language, ''.join(source).replace('</s>', ' '), cleaned)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(cleaned)}, and best translation index is: {best_index}")
    if rewards[best_index] >= THRESHOLD:
        return translations[best_index]
    return None
def _generate_translation(system_prompt, chinese_text):
    """Sample one English translation of ``chinese_text`` under ``system_prompt``."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Translate the following Chinese text to English:\n\n{chinese_text}"}
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # starts directly on the answer. Without it, the model has to produce the
    # header itself, and its plain-text part can leak into the decoded output.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    # One unpadded prompt, so every position is real. Pass the mask explicitly:
    # RewardModel's constructor may set pad_token to eos_token on this shared
    # tokenizer, which makes generate() unable to infer it.
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, not the prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)


def translate_chinese_to_english(chinese_text):
    """Translate Chinese text to English, Plan2Align style.

    Three candidates are sampled under different system prompts (literal,
    formal, creative). The reward model then picks the best one. If no
    candidate reaches the reward threshold, the first (literal) candidate is
    returned.

    Returns:
        The chosen English translation, or ``""`` for empty/blank input.
    """
    if not chinese_text or not chinese_text.strip():
        # Nothing to translate; skip three 512-token generations.
        return ""
    # Generate three different translations with different system prompts
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    translations = [_generate_translation(prompt, chinese_text) for prompt in system_prompts]
    # Pass a copy so the fallback below returns the untouched first candidate
    # even if the selector modifies its argument.
    best_translation = rm_find_best_translation(chinese_text, list(translations))
    if best_translation is None:
        # If no translation meets the threshold, return the first one
        return translations[0]
    return best_translation
# Gradio interface
def process_text(text):
    """Gradio callback: return the English translation of the submitted text."""
    english = translate_chinese_to_english(text)
    return english
# Single-function UI: a 5-line Chinese input box mapped to a 5-line output box.
demo = gr.Interface(
    process_text,
    gr.Textbox(lines=5, placeholder="Enter Chinese text here..."),
    gr.Textbox(lines=5),
    title="Chinese to English Translation with Plan2Align",
    description="This app uses the Plan2Align approach to translate Chinese text to English.",
)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()
|