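# Dependencies (a minimal set; versions are not pinned by the original):
#   pip install torch transformers trl safetensors accelerate
# (accelerate is required for device_map="auto" below.)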
import torch
import safetensors.torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Set device and dtype; with device_map="auto" below, inputs are moved to each
# model's own device, so this global device is only a fallback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16
# Load the base Llama 3.1 8B model for translation
model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID
# Note: apply_chat_template (used below) requires the tokenizer to define a chat
# template; base (non-Instruct) checkpoints may not ship one, in which case the
# Instruct variant or a custom template is needed.
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto",
)
# Load the reward model | |
RM = AutoModelForCausalLMWithValueHead.from_pretrained( | |
'ray24724919/plan2align_rm', | |
torch_dtype=torch_dtype, | |
device_map="auto" | |
) | |
RM.eval() | |
RM.gradient_checkpointing_enable() # if needed for memory efficiency | |
# Thin wrapper around safetensors' loader
def load_file(file_path):
    return safetensors.torch.load_file(file_path)

# Load value head weights if you have the file.
# If you don't have the file, you might need to download it (see the sketch
# after this block) or use the model as is.
try:
    value_head_weights = load_file("value_head.safetensors")  # Replace with actual path
    # Strip the "v_head." prefix so the keys match RM.v_head's state dict
    new_state_dict = {
        key.replace("v_head.", "") if key.startswith("v_head.") else key: value
        for key, value in value_head_weights.items()
    }
    RM.v_head.load_state_dict(new_state_dict)
except FileNotFoundError:
    print("Value head weights file not found. Using default weights.")
# Define translation function with more flexibility
def translate(source_text, target_language="English", model=lm_model):
    """
    Translate text from Chinese to the specified target language.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        model: The model to use for translation

    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
    ]
    # Render the chat template into a single prompt string
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Tokenize the input and move it to the model's device (more robust than a
    # global device when device_map="auto" shards the model)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate translation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (skip the prompt)
    translation = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return translation
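# Quick sanity check (illustrative only; actual output varies because sampling
# is enabled):
# >>> translate("早安")
# 'Good morning'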
# Evaluate the translation using the reward model
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation

    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation},
    ]
    # Render the full conversation, including the candidate translation, for
    # scoring (the reward model is assumed to share the base model's tokenizer)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    # Tokenize the input and move it to the reward model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(RM.pretrained_model.device)
    # The TRL value-head model's forward returns (lm_logits, loss, values),
    # where values has shape (batch, seq_len); use the last token's value as
    # the sequence-level reward score.
    with torch.no_grad():
        _, _, values = RM(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
    reward_score = values[0, -1].item()
    return reward_score
# Function to translate and evaluate in one step
def translate_and_evaluate(source_text, target_language="English"):
    """
    Translate text and evaluate the translation quality in one step.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        tuple: (translation, reward_score)
    """
    translation = translate(source_text, target_language)
    reward_score = evaluate_translation(source_text, translation, target_language)
    return translation, reward_score
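# Because generation is stochastic, the reward model can also be used to rank
# several sampled candidates and keep the highest-scoring one. A minimal
# best-of-n sketch built on the helpers above (the name translate_best_of_n
# and the default n=4 are illustrative choices, not part of the original):
def translate_best_of_n(source_text, target_language="English", n=4):
    candidates = [translate(source_text, target_language) for _ in range(n)]
    scored = [(evaluate_translation(source_text, c, target_language), c) for c in candidates]
    best_score, best_translation = max(scored, key=lambda pair: pair[0])
    return best_translation, best_score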
# Example usage
if __name__ == "__main__":
    # Example with default target language (English)
    source = "你好世界"
    translation, reward_score = translate_and_evaluate(source)
    print(f"Source: {source}")
    print(f"Translation to English: {translation}")
    print(f"Reward Score: {reward_score}")

    # Example with custom target language
    target_language = "French"
    translation, reward_score = translate_and_evaluate(source, target_language)
    print(f"\nSource: {source}")
    print(f"Translation to {target_language}: {translation}")
    print(f"Reward Score: {reward_score}")

    # Interactive mode
    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    while True:
        user_input = input("\nEnter Chinese text to translate: ")
        if user_input.lower() == 'quit':
            break
        target = input("Enter target language (default: English): ").strip()
        if not target:
            target = "English"
        translation, reward_score = translate_and_evaluate(user_input, target)
        print(f"Translation to {target}: {translation}")
        print(f"Reward Score: {reward_score}")