seawolf2357 committed
Commit 011e303 · verified · 1 Parent(s): f86d8cc

Create app.py

Files changed (1)
  1. app.py +160 -0
app.py ADDED
@@ -0,0 +1,160 @@
+ import spaces
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import random
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+ import torch
+ from transformers import pipeline as transformers_pipeline
+ import re
+
+ # Device selection for image generation (GPU if available)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Stable Diffusion XL pipeline
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     "votepurchase/waiNSFWIllustrious_v120",
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True,
+ )
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+ pipe.to(device)
+
+ # Force modules to fp16 for memory efficiency
+ pipe.text_encoder.to(torch.float16)
+ pipe.text_encoder_2.to(torch.float16)
+ pipe.vae.to(torch.float16)
+ pipe.unet.to(torch.float16)
+
+ # Korean → English translator (CPU only)
+ translator = transformers_pipeline(
+     "translation",
+     model="Helsinki-NLP/opus-mt-ko-en",
+     device=-1,  # -1 forces CPU
+ )
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1216
+ korean_regex = re.compile("[\uac00-\ud7af]+")
+
+ def maybe_translate(text: str) -> str:
+     """Translate Korean text to English if Korean characters are detected."""
+     if korean_regex.search(text):
+         translation = translator(text, max_length=256, clean_up_tokenization_spaces=True)
+         return translation[0]["translation_text"]
+     return text
+
+ @spaces.GPU
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     prompt = maybe_translate(prompt)
+     negative_prompt = maybe_translate(negative_prompt)
+
+     if len(prompt.split()) > 60:
+         print("Warning: Prompt may be too long and will be truncated by the model")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     try:
+         output_image = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator,
+         ).images[0]
+         return output_image
+     except RuntimeError as e:
+         print(f"Error during generation: {e}")
+         error_img = Image.new("RGB", (width, height), color=(0, 0, 0))
+         return error_img
+
+ # Custom styling
+ css = """
+ body {background: #0f0f0f; color: #fafafa; font-family: 'Noto Sans', sans-serif;}
+ #col-container {margin: 0 auto; max-width: 640px;}
+ .gr-button {background: #2563eb; color: #ffffff; border-radius: 8px;}
+ #prompt-box textarea {font-size: 1.1rem; height: 3rem;}
+ """
+
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         ## 🖌️ Stable Diffusion XL Playground
+         Generate high quality illustrations with a single prompt.
+         **Tip:** Write in Korean or English. Korean will be translated automatically.
+         """
+     )
+
+     with gr.Column(elem_id="col-container"):
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 elem_id="prompt-box",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt (60 words max)",
+             )
+             run_button = gr.Button("Generate", scale=0)
+
+         result = gr.Image(label="", show_label=False)
+
+         examples = gr.Examples(
+             examples=[
+                 ["어두운 재즈 바에서 담배 연기를 내뿜는 미스터리한 팜파탈, 성인용 애니메이션 스타일"],
+                 ["노출이 강조된 드레스를 입은 고딕 뱀파이어 여왕, 드라마틱 조명, 성인 애니 아트"],
+                 ["은은한 조명의 온천에서 두 연인이 마주 서 있는 관능적 장면, 성인용 애니메이션"],
+                 ["네온이 빛나는 사이버펑크 클럽 무대에서 도발적인 의상을 입은 댄서, 성인 애니 스타일"],
+                 ["달빛 아래 요염한 마법사가 주문을 외우는 판타지 장면, 성인용 애니 일러스트"],
+             ],
+             inputs=[prompt],
+         )
+
+         with gr.Accordion("Advanced Settings", open=False):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 value="nsfw, low quality, watermark, signature",
+             )
+
+             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024
+                 )
+                 height = gr.Slider(
+                     label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Inference steps", minimum=1, maximum=28, step=1, value=28
+                 )
+
+     run_button.click(
+         fn=infer,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+         ],
+         outputs=[result],
+     )
+
+ demo.queue().launch()
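
Not part of the commit, but a quick way to sanity-check the translation path above: the sketch below exercises the Hangul detection regex and the CPU-only Helsinki-NLP/opus-mt-ko-en pipeline in isolation, without importing diffusers or loading the SDXL weights. The test strings are hypothetical placeholders, not prompts from the app.

# Standalone sketch: the Hangul check and CPU translation used by maybe_translate(),
# runnable on its own (assumes transformers plus a sentencepiece backend are installed).
import re
from transformers import pipeline

korean_regex = re.compile("[\uac00-\ud7af]+")  # Hangul Syllables block, same pattern as app.py
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device=-1)  # CPU only

def maybe_translate(text: str) -> str:
    # Translate only when Korean characters are present; otherwise pass through unchanged.
    if korean_regex.search(text):
        return translator(text, max_length=256)[0]["translation_text"]
    return text

print(maybe_translate("달빛 아래 호숫가를 걷는 고양이"))      # Korean in, English out (hypothetical prompt)
print(maybe_translate("a cat walking by a moonlit lake"))  # English input is returned as-is

Keeping the translator on device=-1 (CPU) means the @spaces.GPU allocation around infer() is spent only on the SDXL denoising call, which appears to be the intent of the split in the committed file.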