# staffmanager-llama4-scout / test_llama4.py
# Web-page header text captured together with the file (not part of the script):
#   cpg716's picture
#   Create test_llama4.py
#   b604130 verified
import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import time
import os
from huggingface_hub import login
import requests
from PIL import Image
from io import BytesIO
# Report interpreter and library versions up front so that a later failure
# can be matched to the environment it happened in.
import sys
import transformers

version_report = (
    f"Python version: {sys.version}",
    f"PyTorch version: {torch.__version__}",
    f"Transformers version: {transformers.__version__}",
)
for report_line in version_report:
    print(report_line)
# Get token from environment.
# HUGGINGFACE_TOKEN keeps priority; HF_TOKEN (the variable documented by
# huggingface_hub) is accepted as a fallback. A missing or empty value is
# normalised to None so that the later `token=token` calls do not send an
# empty credential and can instead use credentials cached by a previous login.
token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN") or None

if token:
    # Do not echo any part of the secret into the logs; the length is enough
    # to confirm that something was picked up.
    print(f"Token found ({len(token)} characters)")
    # Login to Hugging Face
    try:
        login(token=token)
        print("Successfully logged in to Hugging Face Hub")
    except Exception as e:
        # Best effort: report the problem and continue, the model-loading
        # tests will surface an authentication failure on their own.
        print(f"Error logging in: {e}")
else:
    print("No token found in environment variables!")
    # Calling login() without a token can only fail, so skip it.
    print("Skipping Hugging Face Hub login")
# Test 1: Simple text generation with Llama 4
def test_text_generation(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    max_new_tokens=100,
):
    """Load a Llama 4 checkpoint and generate a completion for a fixed prompt.

    Args:
        model_id: Hub repository to load. The default is the Llama 4 Scout
            instruct checkpoint; the previously hard-coded
            ``meta-llama/Llama-4-8B-Instruct`` is not a published repository.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        True when loading and generation finished without raising, False
        otherwise (the error and its traceback are printed).
    """
    print("\n=== Testing Text Generation ===")
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print(f"Loading tokenizer from {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        print(f"Loading model from {model_id}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        print("Model and tokenizer loaded successfully!")

        # Simple prompt
        prompt = "Write a short poem about artificial intelligence."
        print(f"Generating text for prompt: '{prompt}'")
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by a wall-clock adjustment during generation.
        start_time = time.perf_counter()
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        elapsed = time.perf_counter() - start_time

        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generation completed in {elapsed:.2f} seconds")
        print(f"Result: {result}")
        return True
    except Exception as e:
        # Top-level boundary of the test: report everything and signal
        # failure through the return value instead of crashing the script.
        print(f"Error in text generation test: {e}")
        import traceback
        print(traceback.format_exc())
        return False
# Test 2: Image-text generation with Llama 4 Scout
def test_image_text_generation():
print("\n=== Testing Image-Text Generation ===")
try:
model_id = "meta-llama/Llama-4-Scout-8B-16E-Instruct" # Using smaller model for faster testing
print(f"Loading processor from {model_id}...")
processor = AutoProcessor.from_pretrained(model_id, token=token)
print(f"Loading model from {model_id}...")
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
token=token,
torch_dtype=torch.bfloat16,
device_map="auto"
)
print("Model and processor loaded successfully!")
# Load a test image
print("Loading test image...")
response = requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
img = Image.open(BytesIO(response.content))
print(f"Image loaded: {img.size}")
# Simple prompt
prompt = "Describe this image in two sentences."
print(f"Creating messages with prompt: '{prompt}'")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "data:image/jpeg;base64," + BytesIO(response.content).getvalue().hex()},
{"type": "text", "text": prompt},
]
},
]
print("Applying chat template...")
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
print("Generating response...")
start_time = time.time()
outputs = model.generate(**inputs, max_new_tokens=100)
end_time = time.time()
result = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(f"Generation completed in {end_time - start_time:.2f} seconds")
print(f"Result: {result}")
return True
except Exception as e:
print(f"Error in image-text generation test: {e}")
import traceback
print(traceback.format_exc())
return False
def main():
    """Run the Llama 4 tests in order and print a summary.

    The image-text test is only attempted when the text test succeeded.

    Returns:
        0 when both tests passed, 1 otherwise, so the value can be used
        directly as the process exit status.
    """
    print("Starting Llama 4 tests...")

    # Run text generation test
    text_success = test_text_generation()

    # Run image-text generation test if text test succeeds
    if text_success:
        image_text_success = test_image_text_generation()
    else:
        print("Skipping image-text test due to text test failure")
        image_text_success = False

    # Summary
    print("\n=== Test Summary ===")
    print(f"Text Generation Test: {'SUCCESS' if text_success else 'FAILED'}")
    print(f"Image-Text Generation Test: {'SUCCESS' if image_text_success else 'FAILED'}")

    all_passed = text_success and image_text_success
    if all_passed:
        print("\nAll tests passed! Your Llama 4 Scout setup is working correctly.")
    else:
        print("\nSome tests failed. Please check the error messages above.")
    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate the outcome to the caller (shell, CI) through the exit status;
    # previously the script exited with 0 even when a test had failed.
    raise SystemExit(main())