File size: 16,524 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import json
import shutil
from typing import Dict, List, Union
import gradio as gr


def get_recursively(d: Union[Dict, List], *args):
    """Walk a nested dict/list structure along ``args`` and return the value reached.

    Each element of ``args`` is one step down the structure:
    lists and tuples are indexed with ``[]``; any other node is looked up with
    ``.get`` (so a missing dict key yields ``None`` rather than raising).
    With no ``args`` the input itself is returned.

    Example: ``get_recursively({"a": [1, 2]}, "a", 1) == 2``.
    """
    node = d
    for key in args:
        # Sequences have no ``.get``; index them so list-valued paths (as the
        # ``Union[Dict, List]`` signature promises) can be traversed.
        node = node[key] if isinstance(node, (list, tuple)) else node.get(key)
    return node


def create_ui():
    """Build the ONNX/Olive settings UI and return the ``gr.Blocks`` instance.

    Tabs:
    * "Provider" -- (re)install an ONNX Runtime execution provider.
    * "Manage cache" and "Customize pass flow" -- only when
      ``run_olive_workflow`` is not ``None``: inspect/remove the ONNX-converted
      and Olive-optimized model caches under ``opts.onnx_cached_models_path``,
      and edit the Olive pass configs stored as JSON files under
      ``<sd_configs_path>/olive/{sd,sdxl}``.
    """
    from modules.ui_common import create_refresh_button
    from modules.ui_components import DropdownMulti
    from modules.shared import log, opts, cmd_opts, refresh_checkpoints
    from modules.sd_models import checkpoint_tiles, get_closet_checkpoint_match
    from modules.paths import sd_configs_path
    from . import run_olive_workflow
    from .execution_providers import ExecutionProvider, install_execution_provider
    from .utils import check_diffusers_cache

    def resolve_cache_dir(name) -> Union[str, None]:
        """Return the path of cache entry ``name``, or ``None`` if it is not a plain directory name.

        Cache names reach the removal callbacks through hidden UI components,
        i.e. from the browser, so anything that could escape the cache root
        (path separators, ``.``/``..``, empty) is rejected before ``rmtree``.
        """
        if not name or name in (".", ".."):
            return None
        if os.sep in name or (os.altsep and os.altsep in name):
            return None
        return os.path.join(opts.onnx_cached_models_path, name)

    def find_caches(model_name: str):
        """Return ``(onnx_converted, optimized_sizes)`` for ``model_name``.

        The ONNX-converted cache is the directory named exactly ``model_name``.
        Olive-optimized caches are named ``{model_name}-{width}w-{height}h``
        (the same pattern the "Remove selected cache" button deletes);
        ``optimized_sizes`` holds their ``(width, height)`` as strings.
        """
        root = opts.onnx_cached_models_path
        caches = os.listdir(root) if os.path.isdir(root) else []
        prefix = f"{model_name}-"
        onnx_converted = False
        optimized_sizes = []
        for cache in caches:
            if cache == model_name:
                onnx_converted = True
                continue
            # A prefix match (not a substring match) so other models whose
            # names merely contain ``model_name`` are not picked up.
            if not cache.startswith(prefix):
                continue
            parts = cache[len(prefix):].split("-")
            if len(parts) != 2:
                continue
            width, height = parts
            if width.endswith("w") and height.endswith("h") and width[:-1].isdigit() and height[:-1].isdigit():
                optimized_sizes.append((width[:-1], height[:-1]))
        return onnx_converted, optimized_sizes

    def create_param_component(type_, value, label: str):
        """Create an editor for an Olive pass parameter whose Python type is ``type_``.

        Only ``bool``, ``str`` and ``int`` parameters get an editor; ``None`` is
        returned for any other type so the caller can skip that parameter.
        """
        if type_ is bool:
            return gr.Checkbox(value=value, label=label)
        if type_ is str:
            return gr.Textbox(value=value, label=label)
        if type_ is int:
            # precision=0 makes the component emit ints, so saved JSON keeps ints instead of e.g. 17.0.
            return gr.Number(value=value, label=label, precision=0)
        return None

    def create_pass_flow_tab(title: str, key: str, label: str, olive_passes, accelerator):
        """Render one "Customize pass flow" model-type tab backed by ``<sd_configs_path>/olive/<key>``.

        ``key`` is both the config sub-directory and the prefix of every tab /
        element id (``"sd"``, ``"sdxl"``); ``label`` appears in the save log
        message. UI edits are applied to the in-memory configs immediately and
        written back to the JSON files only when "Save" is clicked.
        """
        config_path = os.path.join(sd_configs_path, "olive", key)
        configs: Dict[str, Dict] = {}
        # Same for every pass dropdown; build it once.
        pass_type_choices = [x.__name__ for x in olive_passes.REGISTRY.values()]

        def create_change_listener(*path):
            # Store the new value at configs[path[0]]...[path[-2]][path[-1]].
            # The path is bound now, so loop variables cannot leak into it.
            def listener(value):
                get_recursively(configs, *path[:-1])[path[-1]] = value
            return listener

        with gr.TabItem(title, id=key):
            # Sorted for a stable tab order; non-JSON files would break json.load.
            submodels = sorted(f for f in os.listdir(config_path) if f.endswith(".json"))
            with gr.Tabs(elem_id=f"tabs_{key}_submodel"):
                for submodel in submodels:
                    with open(os.path.join(config_path, submodel), "r", encoding="utf-8") as file:
                        config: Dict = json.load(file)
                    configs[submodel] = config

                    submodel_name = submodel[:-len(".json")]
                    with gr.TabItem(submodel_name, id=f"{key}_{submodel_name}"):
                        pass_flows = DropdownMulti(label="Pass flow", value=config["pass_flows"][0], choices=list(config["passes"]))
                        pass_flows.change(fn=create_change_listener(submodel, "pass_flows", 0), inputs=pass_flows)

                        with gr.Tabs(elem_id=f"tabs_{key}_{submodel_name}_pass"):
                            for pass_name, pass_config in config["passes"].items():
                                with gr.TabItem(pass_name, id=f"{key}_{submodel_name}_pass_{pass_name}"):
                                    pass_type = gr.Dropdown(label="Type", value=pass_config["type"], choices=pass_type_choices)

                                    pass_class = getattr(olive_passes, pass_config["type"], olive_passes.Pass)
                                    # Ensure the "config" dict exists so the change listeners can write into it.
                                    values = pass_config.setdefault("config", {})
                                    for config_key, param in pass_class._default_config(accelerator).items(): # pylint: disable=protected-access
                                        component = create_param_component(param.type_, values.get(config_key, param.default_value), config_key)
                                        if component is not None:
                                            component.change(fn=create_change_listener(submodel, "passes", pass_name, "config", config_key), inputs=component)

                                    pass_type.change(fn=create_change_listener(submodel, "passes", pass_name, "type"), inputs=pass_type)

            def save():
                for name, submodel_config in configs.items():
                    with open(os.path.join(config_path, name), "w", encoding="utf-8") as file:
                        json.dump(submodel_config, file, indent=4)
                log.info(f"Olive: config for {label} was saved.")

            save_button = gr.Button(value="Save")
            save_button.click(fn=save)

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs(elem_id="tabs_onnx"):
            with gr.TabItem("Provider", id="onnxep"):
                gr.Markdown("Install ONNX execution provider")
                ep_default = None
                if cmd_opts.use_directml:
                    ep_default = ExecutionProvider.DirectML
                elif cmd_opts.use_cuda:
                    ep_default = ExecutionProvider.CUDA
                elif cmd_opts.use_rocm:
                    ep_default = ExecutionProvider.ROCm
                elif cmd_opts.use_openvino:
                    ep_default = ExecutionProvider.OpenVINO
                ep_checkbox = gr.Radio(label="Execution provider", value=ep_default, choices=ExecutionProvider)
                ep_install = gr.Button(value="Reinstall")
                ep_log = gr.HTML("")
                ep_install.click(fn=install_execution_provider, inputs=[ep_checkbox], outputs=[ep_log])

            if run_olive_workflow is not None:
                import olive.passes as olive_passes
                from olive.hardware.accelerator import AcceleratorSpec, Device
                accelerator = AcceleratorSpec(accelerator_type=Device.GPU, execution_provider=opts.onnx_execution_provider)

                with gr.TabItem("Manage cache", id="manage_cache"):
                    cache_state_dirname = gr.Textbox(value=None, visible=False)
                    with gr.Row():
                        model_dropdown = gr.Dropdown(label="Model", value="Please select model", choices=checkpoint_tiles())
                        create_refresh_button(model_dropdown, refresh_checkpoints, {}, "onnx_cache_refresh_diffusers_model")
                    with gr.Row():
                        def remove_cache_onnx_converted(dirname: str):
                            path = resolve_cache_dir(dirname)
                            if path is None or not os.path.isdir(path):
                                log.warning(f"ONNX converted cache of '{dirname}' does not exist.")
                                return
                            shutil.rmtree(path)
                            log.info(f"ONNX converted cache of '{dirname}' is removed.")
                        cache_onnx_converted = gr.Markdown("Please select model")
                        cache_remove_onnx_converted = gr.Button(value="Remove cache", visible=False)
                        cache_remove_onnx_converted.click(fn=remove_cache_onnx_converted, inputs=[cache_state_dirname,])
                    with gr.Column():
                        cache_optimized_selected = gr.Textbox(value=None, visible=False)
                        def select_cache_optimized(evt: gr.SelectData, data):
                            # Rows are [width, height]; "width,height" is parsed back by remove_cache_optimized.
                            return ",".join(data[evt.index[0]])
                        def remove_cache_optimized(dirname: str, s: str):
                            size = s.split(",") if s else []
                            if not dirname or len(size) != 2:
                                return
                            width, height = size
                            path = resolve_cache_dir(f"{dirname}-{width}w-{height}h")
                            if path is None or not os.path.isdir(path):
                                log.warning(f"Olive processed cache of '{dirname}' does not exist: width={width}, height={height}")
                                return
                            shutil.rmtree(path)
                            log.info(f"Olive processed cache of '{dirname}' is removed: width={width}, height={height}")
                        with gr.Row():
                            # Column order must match the (width, height) rows produced by find_caches.
                            cache_list_optimized_headers = ["width", "height"]
                            cache_list_optimized_types = ["str", "str"]
                            cache_list_optimized = gr.Dataframe(None, label="Optimized caches", show_label=True, overflow_row_behaviour='paginate', interactive=False, max_rows=10, headers=cache_list_optimized_headers, datatype=cache_list_optimized_types, type="array")
                            cache_list_optimized.select(fn=select_cache_optimized, inputs=[cache_list_optimized,], outputs=[cache_optimized_selected,])
                        cache_remove_optimized = gr.Button(value="Remove selected cache", visible=False)
                        cache_remove_optimized.click(fn=remove_cache_optimized, inputs=[cache_state_dirname, cache_optimized_selected,])

                    def cache_update_menus(query: str):
                        checkpoint_info = get_closet_checkpoint_match(query)
                        if checkpoint_info is None:
                            log.error(f"Could not find checkpoint object for '{query}'.")
                            # Gradio needs one value per output; reset the menus instead of returning None.
                            return (
                                None,
                                cache_onnx_converted.update(value="Please select model"),
                                cache_remove_onnx_converted.update(visible=False),
                                None,
                                cache_remove_optimized.update(visible=False),
                            )
                        model_name = os.path.basename(os.path.dirname(os.path.dirname(checkpoint_info.path)) if check_diffusers_cache(checkpoint_info.path) else checkpoint_info.path)
                        onnx_converted, optimized_sizes = find_caches(model_name)
                        return (
                            model_name,
                            cache_onnx_converted.update(value="ONNX model cache of this model exists." if onnx_converted else "ONNX model cache of this model does not exist."),
                            cache_remove_onnx_converted.update(visible=onnx_converted),
                            optimized_sizes or None,
                            cache_remove_optimized.update(visible=bool(optimized_sizes)),
                        )

                    model_dropdown.change(fn=cache_update_menus, inputs=[model_dropdown,], outputs=[
                        cache_state_dirname,
                        cache_onnx_converted, cache_remove_onnx_converted,
                        cache_list_optimized, cache_remove_optimized,
                    ])

                with gr.TabItem("Customize pass flow", id="pass_flow"):
                    with gr.Tabs(elem_id="tabs_model_type"):
                        create_pass_flow_tab("Stable Diffusion", "sd", "SD", olive_passes, accelerator)
                        create_pass_flow_tab("Stable Diffusion XL", "sdxl", "SDXL", olive_passes, accelerator)
    return ui