File size: 12,773 Bytes
ecc4278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import torch
from modules import sd_samplers_kdiffusion, sd_samplers_common

from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
from ldm_patched.modules.samplers import calculate_sigmas_scheduler
from modules import shared


# Solver names for which sample_ode_custom forwards rtol/atol tolerances.
ADAPTIVE_SOLVERS = {
    "dopri8",
    "dopri5",
    "bosh3",
    "fehlberg2",
    "adaptive_heun",
}
# Solver names for which sample_ode_custom passes rtol=None and atol=None.
FIXED_SOLVERS = {
    "euler",
    "midpoint",
    "rk4",
    "heun3",
    "explicit_adams",
    "implicit_adams",
}
# Every solver name from both sets, as an alphabetically sorted list.
ALL_SOLVERS = sorted(ADAPTIVE_SOLVERS | FIXED_SOLVERS)

class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
    """KDiffusionSampler subclass that picks its sampling function by name.

    The sampler name is looked up in a table of ``ldm_patched`` k-diffusion
    sampling functions plus a handful of ODE presets implemented as methods
    on this class. Sigma schedules are produced by
    ``calculate_sigmas_scheduler`` (or by a dedicated timestep layout for
    samplers whose name ends in ``_turbo``).
    """

    # Scheduler name used when none has been recorded or the recorded one is
    # not present in FORGE_SCHEDULERS.
    DEFAULT_SCHEDULER = 'React Cosinusoidal DynSF'

    # Scheduler name as stored on the processing object -> key handed to
    # calculate_sigmas_scheduler.
    FORGE_SCHEDULERS = {
        "Normal": "normal",
        "Karras": "karras",
        "Exponential": "exponential",
        "SGM Uniform": "sgm_uniform",
        "Simple": "simple",
        "DDIM": "ddim_uniform",
        "Align Your Steps": "ays",
        "Align Your Steps GITS": "ays_gits",
        "Align Your Steps 11": "ays_11steps",
        "Align Your Steps 32": "ays_32steps",
        "KL Optimal": "kl_optimal",
        "Beta": "beta",
        "Sinusoidal SF": "sinusoidal_sf",
        "Invcosinusoidal SF": "invcosinusoidal_sf",
        "React Cosinusoidal DynSF": "react_cosinusoidal_dynsf",
    }

    # Sampler names that sample_func routes to the sample_<name> methods below.
    ODE_SAMPLER_NAMES = frozenset({
        'ode_bosh3',
        'ode_fehlberg2',
        'ode_adaptive_heun',
        'ode_dopri5',
        'ode_custom',
    })

    def __init__(self, sd_model, sampler_name, solver=None, rtol=None, atol=None):
        """Resolve ``sampler_name`` to a sampling function and initialise the base class.

        Args:
            sd_model: Model object; its ``forge_objects.unet`` is kept on ``self.unet``.
            sampler_name: Key into the sampler table built below.
            solver, rtol, atol: Stored on the instance as given; the methods in
                this class read their ODE settings from ``shared.opts`` instead.

        Raises:
            ValueError: If ``sampler_name`` is not a known sampler.
        """
        self.sampler_name = sampler_name
        self.scheduler_name = None
        self.unet = sd_model.forge_objects.unet
        self.model = sd_model
        self.solver = solver
        self.rtol = rtol
        self.atol = atol

        sampler_functions = {
            'euler_comfy': k_diffusion_sampling.sample_euler,
            'euler_ancestral_comfy': k_diffusion_sampling.sample_euler_ancestral,
            'heun_comfy': k_diffusion_sampling.sample_heun,
            'dpmpp_2s_ancestral_comfy': k_diffusion_sampling.sample_dpmpp_2s_ancestral,
            'dpmpp_sde_comfy': k_diffusion_sampling.sample_dpmpp_sde,
            'dpmpp_2m_comfy': k_diffusion_sampling.sample_dpmpp_2m,
            'dpmpp_2m_sde_comfy': k_diffusion_sampling.sample_dpmpp_2m_sde,
            'dpmpp_3m_sde_comfy': k_diffusion_sampling.sample_dpmpp_3m_sde,
            # The *_turbo names reuse the regular sampling functions; only the
            # sigma schedule differs (see get_sigmas).
            'euler_ancestral_turbo': k_diffusion_sampling.sample_euler_ancestral,
            'dpmpp_2m_turbo': k_diffusion_sampling.sample_dpmpp_2m,
            'dpmpp_2m_sde_turbo': k_diffusion_sampling.sample_dpmpp_2m_sde,
            'ddpm': k_diffusion_sampling.sample_ddpm,
            'heunpp2': k_diffusion_sampling.sample_heunpp2,
            'ipndm': k_diffusion_sampling.sample_ipndm,
            'ipndm_v': k_diffusion_sampling.sample_ipndm_v,
            'deis': k_diffusion_sampling.sample_deis,
            'euler_cfg_pp': k_diffusion_sampling.sample_euler_cfg_pp,
            'euler_ancestral_cfg_pp': k_diffusion_sampling.sample_euler_ancestral_cfg_pp,
            'dpmpp_2s_ancestral_cfg_pp': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp,
            'dpmpp_2s_ancestral_cfg_pp_dyn': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp_dyn,
            'dpmpp_2s_ancestral_cfg_pp_intern': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp_intern,
            'dpmpp_sde_cfg_pp': k_diffusion_sampling.sample_dpmpp_sde_cfg_pp,
            'dpmpp_2m_cfg_pp': k_diffusion_sampling.sample_dpmpp_2m_cfg_pp,
            'ode_bosh3': self.sample_ode_bosh3,
            'ode_fehlberg2': self.sample_ode_fehlberg2,
            'ode_adaptive_heun': self.sample_ode_adaptive_heun,
            'ode_dopri5': self.sample_ode_dopri5,
            'ode_custom': self.sample_ode_custom,
        }

        sampler_function = sampler_functions.get(sampler_name)
        if sampler_function is None:
            raise ValueError(f"Unknown sampler: {sampler_name}")

        super().__init__(sampler_function, sd_model, None)

    def sample_func(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Run the sampler selected by ``self.sampler_name``.

        ODE sampler names are routed to the matching ``sample_<name>`` method
        of this class; every other name is delegated to the base class.
        """
        if self.sampler_name in self.ODE_SAMPLER_NAMES:
            ode_method = getattr(self, f"sample_{self.sampler_name}")
            return ode_method(model, x, sigmas, extra_args, callback, disable)

        # For non-ODE samplers, use the original sampler function.
        # NOTE(review): this relies on the base class exposing sample_func -- confirm.
        return super().sample_func(model, x, sigmas, extra_args, callback, disable)

    def _sample_ode_preset(self, solver, model, x, sigmas, extra_args, callback, disable):
        """Call ``sample_ode`` with the ``shared.opts.ode_<solver>_*`` settings.

        The rtol/atol options hold base-10 exponents, hence the ``10 **``.
        """
        opts = shared.opts
        return k_diffusion_sampling.sample_ode(
            model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
            solver=solver,
            rtol=10 ** getattr(opts, f"ode_{solver}_rtol"),
            atol=10 ** getattr(opts, f"ode_{solver}_atol"),
            max_steps=getattr(opts, f"ode_{solver}_max_steps"),
        )

    def sample_ode_bosh3(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """ODE sampling with the ``bosh3`` solver and its ``shared.opts`` settings."""
        return self._sample_ode_preset("bosh3", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_fehlberg2(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """ODE sampling with the ``fehlberg2`` solver and its ``shared.opts`` settings."""
        return self._sample_ode_preset("fehlberg2", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_adaptive_heun(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """ODE sampling with the ``adaptive_heun`` solver and its ``shared.opts`` settings."""
        return self._sample_ode_preset("adaptive_heun", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_dopri5(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """ODE sampling with the ``dopri5`` solver and its ``shared.opts`` settings."""
        return self._sample_ode_preset("dopri5", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_custom(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """ODE sampling with the solver named by ``shared.opts.ode_custom_solver``.

        Tolerances are only forwarded for solvers listed in ADAPTIVE_SOLVERS;
        any other solver receives ``rtol=None`` and ``atol=None``.
        """
        solver = shared.opts.ode_custom_solver
        rtol = 10**shared.opts.ode_custom_rtol if solver in ADAPTIVE_SOLVERS else None
        atol = 10**shared.opts.ode_custom_atol if solver in ADAPTIVE_SOLVERS else None
        max_steps = shared.opts.ode_custom_max_steps

        return k_diffusion_sampling.sample_ode(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
                                               solver=solver, rtol=rtol, atol=atol, max_steps=max_steps)

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Record ``p.scheduler`` for get_sigmas, then defer to the base ``sample``."""
        self.scheduler_name = p.scheduler
        return super().sample(p, x, conditioning, unconditional_conditioning, steps, image_conditioning)

    def get_sigmas(self, p, steps):
        """Return the sigma schedule for ``steps`` steps on the UNet's load device.

        ``p`` is accepted for interface compatibility and is not read here.
        Samplers whose name ends in ``_turbo`` ignore the selected scheduler.
        """
        if self.scheduler_name is None:
            # No scheduler recorded yet (sample() not called): use the default.
            self.scheduler_name = self.DEFAULT_SCHEDULER

        # Unknown scheduler names fall back to the default scheduler's *key*.
        # The fallback used to be the display name itself, which is not one of
        # the keys this table hands to calculate_sigmas_scheduler.
        matched_scheduler = self.FORGE_SCHEDULERS.get(
            self.scheduler_name,
            self.FORGE_SCHEDULERS[self.DEFAULT_SCHEDULER],
        )

        if self.sampler_name.endswith('_turbo'):
            # Turbo samplers: `steps` timesteps spaced 1000/steps apart, in
            # descending order starting at 999, clipped to [0, 999], with a
            # trailing zero sigma appended.
            timesteps = torch.flip(torch.arange(1, steps + 1) * float(1000.0 / steps) - 1, (0,)).round().long().clip(0, 999)
            sigmas = self.unet.model.model_sampling.sigma(timesteps)
            sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        else:
            sigmas = calculate_sigmas_scheduler(self.unet.model, matched_scheduler, steps, is_sdxl=getattr(self.model, "is_sdxl", False))

        return sigmas.to(self.unet.load_device)


def build_constructor(sampler_name):
    """Return a one-argument factory that builds an AlterSampler for ``sampler_name``.

    The returned callable takes the model and forwards it, together with the
    captured ``sampler_name``, to ``AlterSampler``.
    """
    return lambda model: AlterSampler(model, sampler_name)

# (label, sampler_name) pairs. The label is the first SamplerData argument;
# the sampler_name is passed to build_constructor and is also the entry's
# single alias.
_ALTER_SAMPLER_SPECS = (
    ('Euler Comfy', 'euler_comfy'),
    ('Euler A Comfy', 'euler_ancestral_comfy'),
    ('Heun Comfy', 'heun_comfy'),
    ('DPM++ 2S Ancestral Comfy', 'dpmpp_2s_ancestral_comfy'),
    ('DPM++ SDE Comfy', 'dpmpp_sde_comfy'),
    ('DPM++ 2M Comfy', 'dpmpp_2m_comfy'),
    ('DPM++ 2M SDE Comfy', 'dpmpp_2m_sde_comfy'),
    ('DPM++ 3M SDE Comfy', 'dpmpp_3m_sde_comfy'),
    ('Euler A Turbo', 'euler_ancestral_turbo'),
    ('DPM++ 2M Turbo', 'dpmpp_2m_turbo'),
    ('DPM++ 2M SDE Turbo', 'dpmpp_2m_sde_turbo'),
    ('DDPM', 'ddpm'),
    ('HeunPP2', 'heunpp2'),
    ('IPNDM', 'ipndm'),
    ('IPNDM_V', 'ipndm_v'),
    ('DEIS', 'deis'),
    ('Euler CFG++', 'euler_cfg_pp'),
    ('Euler Ancestral CFG++', 'euler_ancestral_cfg_pp'),
    ('DPM++ 2S Ancestral CFG++', 'dpmpp_2s_ancestral_cfg_pp'),
    ('DPM++ 2S Ancestral CFG++ Dyn', 'dpmpp_2s_ancestral_cfg_pp_dyn'),
    ('DPM++ 2S Ancestral CFG++ Intern', 'dpmpp_2s_ancestral_cfg_pp_intern'),
    ('DPM++ SDE CFG++', 'dpmpp_sde_cfg_pp'),
    ('DPM++ 2M CFG++', 'dpmpp_2m_cfg_pp'),
    ('ODE (Bosh3)', 'ode_bosh3'),
    ('ODE (Fehlberg2)', 'ode_fehlberg2'),
    ('ODE (Adaptive Heun)', 'ode_adaptive_heun'),
    ('ODE (Dopri5)', 'ode_dopri5'),
    ('ODE Custom', 'ode_custom'),
)

# One SamplerData entry per spec, in the order listed above; each entry gets
# its own alias list and its own empty options dict.
samplers_data_alter = [
    sd_samplers_common.SamplerData(label, build_constructor(sampler_name=name), [name], {})
    for label, name in _ALTER_SAMPLER_SPECS
]