|
import gc |
|
import unittest |
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
|
|
from diffusers import FluxPipeline, FluxPriorReduxPipeline |
|
from diffusers.utils import load_image |
|
from diffusers.utils.testing_utils import ( |
|
numpy_cosine_similarity_distance, |
|
require_big_gpu_with_torch_cuda, |
|
slow, |
|
torch_device, |
|
) |
|
|
|
|
|
@slow |
|
@require_big_gpu_with_torch_cuda |
|
@pytest.mark.big_gpu_with_torch_cuda |
|
class FluxReduxSlowTests(unittest.TestCase): |
|
pipeline_class = FluxPriorReduxPipeline |
|
repo_id = "YiYiXu/yiyi-redux" |
|
base_pipeline_class = FluxPipeline |
|
base_repo_id = "black-forest-labs/FLUX.1-schnell" |
|
|
|
def setUp(self): |
|
super().setUp() |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
def tearDown(self): |
|
super().tearDown() |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
def get_inputs(self, device, seed=0): |
|
init_image = load_image( |
|
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png" |
|
) |
|
return {"image": init_image} |
|
|
|
def get_base_pipeline_inputs(self, device, seed=0): |
|
if str(device).startswith("mps"): |
|
generator = torch.manual_seed(seed) |
|
else: |
|
generator = torch.Generator(device="cpu").manual_seed(seed) |
|
|
|
return { |
|
"num_inference_steps": 2, |
|
"guidance_scale": 2.0, |
|
"output_type": "np", |
|
"generator": generator, |
|
} |
|
|
|
def test_flux_redux_inference(self): |
|
pipe_redux = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) |
|
pipe_base = self.base_pipeline_class.from_pretrained( |
|
self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None |
|
) |
|
pipe_redux.to(torch_device) |
|
pipe_base.enable_model_cpu_offload() |
|
|
|
inputs = self.get_inputs(torch_device) |
|
base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device) |
|
|
|
redux_pipeline_output = pipe_redux(**inputs) |
|
image = pipe_base(**base_pipeline_inputs, **redux_pipeline_output).images[0] |
|
|
|
image_slice = image[0, :10, :10] |
|
expected_slice = np.array( |
|
[ |
|
0.30078125, |
|
0.37890625, |
|
0.46875, |
|
0.28125, |
|
0.36914062, |
|
0.47851562, |
|
0.28515625, |
|
0.375, |
|
0.4765625, |
|
0.28125, |
|
0.375, |
|
0.48046875, |
|
0.27929688, |
|
0.37695312, |
|
0.47851562, |
|
0.27734375, |
|
0.38085938, |
|
0.4765625, |
|
0.2734375, |
|
0.38085938, |
|
0.47265625, |
|
0.27539062, |
|
0.37890625, |
|
0.47265625, |
|
0.27734375, |
|
0.37695312, |
|
0.47070312, |
|
0.27929688, |
|
0.37890625, |
|
0.47460938, |
|
], |
|
dtype=np.float32, |
|
) |
|
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) |
|
|
|
assert max_diff < 1e-4 |
|
|