Nikhil-4brains committed on
Commit
01ba08b
Β·
verified Β·
1 Parent(s): 1426bb9

Upload 6 files

Browse files
Files changed (6) hide show
  1. .gitattributes +53 -35
  2. README.md +13 -12
  3. app.py +240 -0
  4. log.py +78 -0
  5. ominicontrol.py +158 -0
  6. requirements.txt +10 -0
.gitattributes CHANGED
@@ -1,35 +1,53 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ OminiControl/assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
37
+ OminiControl/assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
38
+ OminiControl/assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
39
+ OminiControl/assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
40
+ OminiControl/assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
41
+ OminiControl/assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
42
+ OminiControl/assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
43
+ OminiControl/assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
44
+ OminiControl/assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
45
+ OminiControl/assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
46
+ OminiControl/assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
47
+ OminiControl/assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
48
+ examples/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text
49
+ examples/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -text
50
+ examples/doge.jpg filter=lfs diff=lfs merge=lfs -text
51
+ examples/oiiai.png filter=lfs diff=lfs merge=lfs -text
52
+ examples/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text
53
+ examples/steve.webp filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Ghibli Art
3
- emoji: πŸ“Š
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.25.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: OminiControl Art
3
+ emoji: 🎨
4
+ colorFrom: green
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.23.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unknown
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from ominicontrol import generate_image, vote_feedback
4
+ import os
5
+
6
+
7
+ USE_ZERO_GPU = os.environ.get("USE_ZERO_GPU", "0") == "1"
8
+
9
+ css = """
10
+ .inputPanel {
11
+ width: 320px;
12
+ display: flex;
13
+ align-items: center;
14
+ }
15
+ .outputPanel {
16
+ display: flex;
17
+ align-items: center;
18
+ }
19
+ .hint {
20
+ font-size: 14px;
21
+ color: #777;
22
+ # border: 1px solid #ccc;
23
+ padding: 4px;
24
+ border-radius: 5px;
25
+ # background-color: #efefef;
26
+ }
27
+ """
28
+
29
+ header = """
30
+ # 🎨 OminiControl Art
31
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
32
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
33
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/πŸ€—OminiControl-Demo-ffbd45.svg" alt="HuggingFace"></a>
34
+ <a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
35
+ </div>
36
+
37
+ ***OminiControl Art*** distills the artistic style of [GPT-4o](https://openai.com/index/introducing-4o-image-generation/) into the [FLUX.1](https://blackforestlabs.ai/) model, building on the foundation of [OminiControl](https://github.com/Yuanshi9815/OminiControl)✨.
38
+ Enjoy playing around! 🌈
39
+ """
40
+
41
+
42
+ def style_transfer(image, style):
43
+ return image
44
+
45
+
46
+ styles = [
47
+ "Studio Ghibli",
48
+ "Irasutoya Illustration",
49
+ "The Simpsons",
50
+ "Snoopy",
51
+ ]
52
+
53
+
54
def gradio_interface():
    """Build and return the Gradio Blocks UI.

    Layout: a left input panel (condition image, style picker, advanced
    settings, examples) and a right output panel (result image plus a
    feedback widget that is revealed once an inference id arrives).
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(header)
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil",
                    label="Condition Image",
                    width=400,
                    height=400,
                )
                style = gr.Radio(
                    styles,
                    label="🎨 Select Style",
                    value=styles[0],
                )
                # Advanced settings
                with gr.Accordion(
                    "⚙️ Advanced Settings", open=False
                ) as advanced_settings:
                    inference_mode = gr.Radio(
                        ["High Quality", "Fast"],
                        value="High Quality",
                        label="Generating Mode",
                    )
                    image_ratio = gr.Radio(
                        ["Auto", "Square(1:1)", "Portrait(2:3)", "Landscape(3:2)"],
                        label="Image Ratio",
                        value="Auto",
                    )
                    use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        visible=(not use_random_seed.value),
                    )
                    # Show the manual seed box only when random seeding is off.
                    use_random_seed.change(
                        lambda x: gr.update(visible=(not x)),
                        use_random_seed,
                        seed,
                        show_progress="hidden",
                    )
                    image_guidance = gr.Slider(
                        label="Image Guidance",
                        minimum=1.1,
                        maximum=5,
                        value=1.5,
                        step=0.1,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    # Guidance is forced to 1.0 in "Fast" mode (see
                    # ominicontrol.generate_image), so disable the slider there.
                    inference_mode.change(
                        lambda x: gr.update(interactive=(x == "High Quality")),
                        inference_mode,
                        image_guidance,
                        show_progress="hidden",
                    )

                btn = gr.Button("Generate Image", variant="primary")

                # FIX: this accordion used to rebind `advanced_settings`,
                # shadowing the settings accordion above.
                with gr.Accordion("🏞️ Examples", open=True) as examples_accordion:
                    examples = gr.Examples(
                        examples=[
                            ["examples/DistractedBoyfriend.webp", styles[0]],
                            ["examples/steve.webp", styles[0]],
                            ["examples/oiiai.png", styles[1]],
                            ["examples/doge.jpg", styles[1]],
                            ["examples/breakingbad.jpg", styles[2]],
                            ["examples/PulpFiction.jpg", styles[3]],
                        ],
                        inputs=[original_image, style],
                    )

            with gr.Column(elem_classes="outputPanel"):
                output_image = gr.Image(
                    type="pil",
                    width=600,
                    height=600,
                    label="Output Image",
                    interactive=False,
                    sources=None,
                )
                # Hidden carrier for the inference log id; its `change` event
                # reveals the feedback widget below.
                inference_id = gr.Textbox(
                    visible=False,
                    interactive=False,
                )

                # Feedback buttons
                with gr.Column(visible=False) as feedback:
                    gr.Markdown(
                        """
Your feedback improves the model! Please let us know how you feel about the generated image.
""",
                    )
                    with gr.Row() as feedback_buttons:
                        upvote = gr.Button("👍 I like it", variant="primary")
                        downvote = gr.Button("👎 It looks bad")

                def feedback_func(vote):
                    # `vote` is "1" (up) or "0" (down). Renamed from
                    # `feedback`, which shadowed the feedback Column above.
                    def func(inputs):
                        print(f"Feedback: {vote}, Inference ID: {inputs}")
                        vote_feedback(log_id=inputs, feedback=vote)
                        # Hide the widget again once a vote is recorded.
                        return gr.update(visible=False)

                    return func

                upvote.click(feedback_func("1"), inference_id, feedback)
                downvote.click(feedback_func("0"), inference_id, feedback)

                inference_id.change(
                    lambda x: gr.update(visible=True), output_image, feedback
                )

                # Beta-quality styles (everything except the first) show a
                # "try again" hint.
                hint = gr.Markdown(
                    """
<div style="text-align: center; width: 100%;">
<b>Note: The selected style is in beta testing.</b> Feel free to try a few more times to get a better result.
</div>
""",
                    visible=False,
                )
                style.change(
                    lambda x: gr.update(visible=x in styles[1:]),
                    style,
                    hint,
                )

        btn.click(
            fn=infer,
            inputs=[
                style,
                original_image,
                inference_mode,
                image_guidance,
                image_ratio,
                use_random_seed,
                seed,
                steps,
            ],
            outputs=[
                output_image,
                inference_id,
            ],
        )

    return demo
207
+
208
+
209
def infer(
    style,
    original_image,
    inference_mode,
    image_guidance,
    image_ratio,
    use_random_seed,
    seed,
    steps,
):
    """Gradio callback: forward the UI settings to `generate_image`.

    Returns the generated image and the inference log id (used later by
    the feedback widget).
    """
    print(
        f"Style: {style}, Inference Mode: {inference_mode}, Image Guidance: {image_guidance}, Image Ratio: {image_ratio}, Use Random Seed: {use_random_seed}, Seed: {seed}"
    )
    options = dict(
        image=original_image,
        style=style,
        inference_mode=inference_mode,
        image_guidance=image_guidance,
        image_ratio=image_ratio,
        use_random_seed=use_random_seed,
        seed=seed,
        steps=steps,
    )
    result_image, inference_id = generate_image(**options)
    return result_image, inference_id
233
+
234
+
235
# On ZeroGPU Spaces the heavy entry point must be decorated so that a GPU
# is attached only for the duration of each call.
if USE_ZERO_GPU:
    infer = spaces.GPU(infer)

if __name__ == "__main__":
    demo = gradio_interface()
    # Bind on all interfaces (needed inside the Space container); SSR is
    # disabled explicitly.
    demo.launch(server_name="0.0.0.0", ssr_mode=False)
log.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import boto3
2
+ import uuid
3
+ import time
4
+ import os
5
+
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+
10
# Longest image side (in pixels) kept when archiving images to S3.
MAX_PIXELS = 2048

# S3 bucket and DynamoDB table names, all supplied via environment.
AWS_BUCKET_NAME = os.environ.get("AWS_BUCKET_NAME", "")
AWS_INFERENCE_LOG_TABLE = os.environ.get("AWS_INFERENCE_LOG_TABLE", "")
AWS_FEEDBACK_LOG_TABLE = os.environ.get("AWS_FEEDBACK_LOG_TABLE", "")


AWS_REGION = os.environ.get("AWS_REGION", "")
AWS_ACCESS_ID = os.environ.get("AWS_ACCESS_ID", "")
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", "")


# Shared boto3 credentials/region for both clients.
# NOTE(review): unset env vars yield empty-string credentials here, which
# boto3 will try to use verbatim — confirm the Space always sets them.
aws_cfg = {
    "aws_access_key_id": AWS_ACCESS_ID,
    "aws_secret_access_key": AWS_ACCESS_KEY,
    "region_name": AWS_REGION,
}

s3_client = boto3.client("s3", **aws_cfg)  # image archive
dynamodb = boto3.resource("dynamodb", **aws_cfg)

# One table for inference runs, one for user feedback votes.
inference_log = dynamodb.Table(AWS_INFERENCE_LOG_TABLE)
feedback_log = dynamodb.Table(AWS_FEEDBACK_LOG_TABLE)
33
+
34
+
35
def get_metadata():
    """Return common log metadata: a fresh hex id and a local timestamp."""
    record_id = uuid.uuid4().hex
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    return {"_id": record_id, "created_at": timestamp}
40
+
41
+
42
def insert_log(table_type: str, data: dict):
    """Write one item to the inference or feedback DynamoDB table.

    Args:
        table_type: Either ``"inference"`` or ``"feedback"``; selects the table.
        data: Item attributes; merged with a generated ``_id``/``created_at``.

    Returns:
        Tuple of (DynamoDB ``put_item`` response, generated ``_id``).

    Raises:
        ValueError: If ``table_type`` is not a known table.
    """
    # Raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently route bad table names to the feedback table.
    if table_type not in ("inference", "feedback"):
        raise ValueError(f"Invalid table type: {table_type!r}")
    table = inference_log if table_type == "inference" else feedback_log
    metadata = get_metadata()
    response = table.put_item(
        Item={
            **data,
            **metadata,
        }
    )
    return response, metadata["_id"]
53
+
54
+
55
+ # Example usage:
56
+ # insert_log("inference", {"data": "test"})
57
+ # insert_log("feedback", {"data": "test"})
58
+
59
+
60
def get_image_obj(image: Image.Image) -> BytesIO:
    """Serialize *image* to an in-memory WEBP buffer, capped at MAX_PIXELS.

    Works on a copy because PIL's ``thumbnail`` resizes in place — the
    original version mutated the caller's image as a side effect.
    (Annotation fixed too: ``Image`` is the module, ``Image.Image`` the class.)
    """
    img = image.copy()
    img.thumbnail((MAX_PIXELS, MAX_PIXELS))
    image_obj = BytesIO()
    img.save(image_obj, format="WEBP")
    image_obj.seek(0)
    return image_obj
66
+
67
+
68
def log_image(image: Image.Image) -> str:
    """Upload *image* to S3 as WEBP and return its generated id.

    The id doubles as the S3 key stem (``images/<id>.webp``), so the stored
    object can be joined back to DynamoDB log rows that reference it.
    (Annotation fixed: ``Image`` is the module, ``Image.Image`` the class.)
    """
    metadata = get_metadata()
    image_obj = get_image_obj(image)
    s3_key = f"images/{metadata['_id']}.webp"
    s3_client.upload_fileobj(image_obj, AWS_BUCKET_NAME, s3_key)
    return metadata["_id"]
74
+
75
+
76
+ # Example usage:
77
+ # image = Image.open("examples/doge.jpg")
78
+ # log_image(image)
ominicontrol.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from OminiControl.src.flux.condition import Condition
4
+ from PIL import Image
5
+ import random
6
+
7
+ from OminiControl.src.flux.generate import generate, seed_everything
8
+
9
+ from log import insert_log, log_image
10
+
11
print("Loading model...")
# FLUX.1-dev in bfloat16; requires a CUDA device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

pipe.unload_lora_weights()

# One LoRA adapter per supported style. Adapter names must match the values
# of the style->adapter mapping in `generate_image`.
# (Was four copy-pasted calls with pointless f-string prefixes.)
for _adapter in ("ghibli", "irasutoya", "simpsons", "snoopy"):
    pipe.load_lora_weights(
        "Yuanshi/OminiControlStyle",
        weight_name=f"v0/{_adapter}.safetensors",
        adapter_name=_adapter,
    )
39
+
40
+
41
def generate_image(
    image,
    style,
    inference_mode,
    image_guidance,
    image_ratio,
    steps,
    use_random_seed,
    seed,
):
    """Run style-transfer inference and log condition/result images.

    Args:
        image: Condition image (assumed PIL.Image — TODO confirm; it is
            mutated/rebound to a ~512px, 16px-aligned version below).
        style: One of the display names in `app.styles`.
        inference_mode: "High Quality" or "Fast" (forces guidance to 1.0).
        image_guidance: Guidance scale; only honored in High Quality mode.
        image_ratio: "Auto" or an explicit ratio label.
        steps: Number of diffusion steps.
        use_random_seed: When True, `seed` is replaced by a random one.
        seed: Seed used when `use_random_seed` is False.

    Returns:
        Tuple of (result PIL image, inference log id for feedback voting).
    """
    # Archive the raw condition image to S3 before any processing.
    condition_id = log_image(image)

    # Prepare Condition
    def resize(img, factor=16):
        # Center-crop so both sides are multiples of `factor` (16).
        w, h = img.size
        new_w, new_h = w // factor * factor, h // factor * factor
        padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2
        img = img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))
        return img

    # Set Adapter: map UI style name to the LoRA adapter loaded at import time.
    activate_adapter_name = {
        "Studio Ghibli": "ghibli",
        "Irasutoya Illustration": "irasutoya",
        "The Simpsons": "simpsons",
        "Snoopy": "snoopy",
    }[style]
    pipe.set_adapters(activate_adapter_name)

    # Scale the longest side to 512px, then crop to a 16px-aligned size.
    factor = 512 / max(image.size)
    image = resize(
        image.resize(
            (int(image.size[0] * factor), int(image.size[1] * factor)),
            Image.LANCZOS,
        )
    )
    # Horizontal token offset for the condition (width/16 latent columns).
    delta = -image.size[0] // 16
    condition = Condition(
        "subject",
        # activate_adapter_name,
        image,
        position_delta=(0, delta),
    )

    # Prepare seed (32-bit range); seed_everything makes the run reproducible.
    if use_random_seed:
        seed = random.randint(0, 2**32 - 1)
    seed_everything(seed)

    # Image guidance scale: Fast mode disables extra guidance.
    image_guidance = 1.0 if inference_mode == "Fast" else image_guidance

    # Output size: pick the supported ratio closest to the (resized) input,
    # or use the explicitly requested one.
    if image_ratio == "Auto":
        r = image.size[0] / image.size[1]
        ratio = min([0.67, 1, 1.5], key=lambda x: abs(x - r))
    else:
        ratio = {
            "Square(1:1)": 1,
            "Portrait(2:3)": 0.67,
            "Landscape(3:2)": 1.5,
        }[image_ratio]
    width, height = {
        0.67: (640, 960),
        1: (640, 640),
        1.5: (960, 640),
    }[ratio]

    print(
        f"Image Ratio: {image_ratio}, Inference Mode: {inference_mode}, Image Guidance: {image_guidance}, Seed: {seed}, Steps: {steps}, Size: {width}x{height}"
    )
    # Generate
    result_img = generate(
        pipe,
        prompt="",
        conditions=[condition],
        num_inference_steps=steps,
        width=width,
        height=height,
        image_guidance_scale=image_guidance,
        default_lora=True,
        max_sequence_length=32,
    ).images[0]
    # result_img = image
    # Archive the result and record the run (all values stringified for
    # DynamoDB).
    result_id = log_image(result_img)

    log_data = {
        "condition": condition_id,
        "result": result_id,
        "prompt": "",
        "inference_mode": inference_mode,
        "image_guidance_scale": image_guidance,
        "seed": seed,
        "steps": steps,
        "style": style,
        "width": width,
        "height": height,
    }
    log_data = {k: str(v) for k, v in log_data.items()}

    _, log_id = insert_log("inference", log_data)

    print(f"Image log ID: {log_id}")

    return result_img, log_id
146
+
147
+
148
def vote_feedback(
    log_id,
    feedback,
):
    """Persist a thumbs-up/down vote ("1"/"0") for a previous inference run."""
    payload = {
        "log_id": str(log_id),
        "feedback": str(feedback),
    }
    insert_log("feedback", payload)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ diffusers
3
+ peft
4
+ opencv-python
5
+ protobuf
6
+ sentencepiece
7
+ jupyter
8
+ torchao
9
+
10
+ boto3