import os
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login

# Configure logging and log in to the Hugging Face Hub
# (the token is read from the LA_NAME environment variable / Space secret)
logging.basicConfig(level=logging.INFO)
login(token=os.environ.get("LA_NAME"))

# Constants
THRESHOLD = 2  # Minimum acceptable reward score, from Plan2Align

# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the base translation model once at startup
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)


class RewardModel:
    """Wraps the Plan2Align reward model and scores candidate translations."""

    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Set a chat template if the tokenizer does not already have one
        if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
            # Llama 3's default chat template
            self.tokenizer.chat_template = (
                "<|begin_of_text|>{% for message in messages %}"
                "{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'"
                " + message['content'] + '<|eot_id|>'}}{% endfor %}"
            )

        print("Loading reward model...")
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            device_map={"": 0},  # Force the model to stay on GPU 0
            torch_dtype=torch_dtype,
        )
        self.RM.eval()
        print("Reward model loaded successfully!")

    def _create_single_message(self, language, source, translation):
        # Build a (system, user, assistant) triple so the value head scores
        # the given translation as the assistant's response.
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result.",
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:",
            },
            {
                "role": "assistant",
                "content": translation,
            },
        ]

    def _process_inputs(self, messages):
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)
            return {"input_ids": input_ids, "attention_mask": attention_mask}
        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise

    def reward_fn(self, language, source, translations):
        """Score each candidate translation; returns one scalar reward per candidate."""
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    outputs = self.RM(**inputs, return_value=True)
                    rewards = outputs[2]  # value-head output
                # Use the value at the final token as the sequence-level reward
                reward = rewards[0, -1].cpu().item()
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise

    def get_len(self, language, translations):
        """Return the total token count across all translations."""
        try:
            total_len = 0
            for translation in translations:
                total_len += self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
            return total_len
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise


# Create the reward model instance with the already loaded tokenizer
reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")


# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
    """Return the index (0 or 1) of the higher-scoring translation."""
    translations = [translation0, translation1]
    for t_i in range(len(translations)):
        # Join segment lists and strip end-of-sequence markers
        translations[t_i] = "".join(translations[t_i]).replace("</s>", " ")
    rewards = reward_model.reward_fn(language, source.replace("</s>", " "), translations)
    best_index = rewards.index(max(rewards))
    return best_index


def rm_find_best_translation(source, translations, language="English"):
    """Return the highest-scoring translation, or None if none clears THRESHOLD."""
    copy_translations = translations.copy()
    if len(translations) < 2:
        return translations[0] if translations else None
    for t_i in range(len(translations)):
        translations[t_i] = "".join(translations[t_i]).replace("</s>", " ")
    rewards = reward_model.reward_fn(language, "".join(source).replace("</s>", " "), translations)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(translations)}, and best translation index is: {best_index}")
    if rewards[best_index] >= THRESHOLD:
        return copy_translations[best_index]
    else:
        return None


def translate_chinese_to_english(chinese_text):
    # Generate three candidate translations with different system prompts
    translations = []
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story.",
    ]
    for prompt in system_prompts:
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Translate the following Chinese text to English:\n\n{chinese_text}"},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # prompt the model to answer as the assistant
            return_tensors="pt",
        ).to(device)
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
        # Decode only the newly generated tokens, skipping the prompt
        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        translations.append(translation)

    # Use the reward model to pick the best candidate
    best_translation = rm_find_best_translation(chinese_text, translations)
    if best_translation is None:
        # If no translation meets the threshold, fall back to the first one
        return translations[0]
    return best_translation


# Gradio interface
def process_text(text):
    return translate_chinese_to_english(text)


demo = gr.Interface(
    fn=process_text,
    inputs=gr.Textbox(lines=5, placeholder="Enter Chinese text here..."),
    outputs=gr.Textbox(lines=5),
    title="Chinese to English Translation with Plan2Align",
    description="This app uses the Plan2Align approach to translate Chinese text to English.",
)

if __name__ == "__main__":
    demo.launch()
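
# Usage sketch (illustrative; not part of the original app): the translator can
# also be called directly from Python, bypassing the Gradio UI. The module name
# and sample sentence below are assumptions for demonstration only.
#
#     from app import translate_chinese_to_english  # assuming this file is app.py
#     print(translate_chinese_to_english("今天天气很好。"))  # "The weather is nice today."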