Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM,
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoModelForVision2Seq, pipeline
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
@@ -11,290 +11,58 @@ from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility
 import re

-# Initialize the translation pipeline (Korean → English)
-translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

-
-
-
-
-
-        return translation
-    except Exception as e:
-        print(f"Translation error: {e}")
-        return prompt
-    return prompt

-
-model_path = "deepseek-ai/Janus-Pro-7B"
-config = AutoConfig.from_pretrained(model_path)
-language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    language_config=language_config,
-    trust_remote_code=True
-)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)

-
-
-

-
-
-sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
-
-@torch.inference_mode()
-@spaces.GPU(duration=120)
-def multimodal_understanding(image, question, seed, top_p, temperature):
-    # (omitted) body of the existing multimodal understanding function, kept as-is...
-    torch.cuda.empty_cache()
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    conversation = [
-        {
-            "role": "<|User|>",
-            "content": f"<image_placeholder>\n{question}",
-            "images": [image],
-        },
-        {"role": "<|Assistant|>", "content": ""},
-    ]
-
-    pil_images = [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

-
-
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False if temperature == 0 else True,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )

-
-
-
-
-             parallel_size: int = 5, cfg_weight: float = 5,
-             image_token_num_per_image: int = 576, patch_size: int = 16):
-    torch.cuda.empty_cache()
-
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = vl_chat_processor.pad_id
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
-
-    pkv = None
-    for i in range(image_token_num_per_image):
-        with torch.no_grad():
-            outputs = vl_gpt.language_model.model(
-                inputs_embeds=inputs_embeds,
-                use_cache=True,
-                past_key_values=pkv
-            )
-            pkv = outputs.past_key_values
-            hidden_states = outputs.last_hidden_state
-            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-            logit_cond = logits[0::2, :]
-            logit_uncond = logits[1::2, :]
-            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-            probs = torch.softmax(logits / temperature, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-            generated_tokens[:, i] = next_token.squeeze(dim=-1)
-            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-
-            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-            inputs_embeds = img_embeds.unsqueeze(dim=1)

-
-
-
-
-
-
-
-    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-    visual_img[:, :, :] = dec
-    return visual_img
-
-@torch.inference_mode()
-@spaces.GPU(duration=120)
-def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
-    # Translation: convert the input prompt to English if it contains Korean
-    prompt = translate_if_korean(prompt)

-
-
-
-
-
-
-
-

-
-
-                    {'role': '<|Assistant|>', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-            conversations=messages,
-            sft_format=vl_chat_processor.sft_format,
-            system_prompt=''
-        )
-        text = text + vl_chat_processor.image_start_tag
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(
-            input_ids,
-            width // 16 * 16,
-            height // 16 * 16,
-            cfg_weight=guidance,
-            parallel_size=parallel_size,
-            temperature=t2i_temperature
-        )
-        images = unpack(
-            patches,
-            width // 16 * 16,
-            height // 16 * 16,
-            parallel_size=parallel_size
-        )
-
-        stime = time.time()
-        ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
-        print(f'upsample time: {time.time() - stime}')
-        return ret_images
-
-@spaces.GPU(duration=60)
-def image_upsample(img: Image.Image) -> Image.Image:
-    if img is None:
-        raise Exception("Image not uploaded")
-    width, height = img.size
-    if width >= 5000 or height >= 5000:
-        raise Exception("The image is too large.")
-    global sr_model
-    result = sr_model.predict(img.convert('RGB'))
-    return result
-
-# Custom CSS for a sleek, modern and highly readable interface
-custom_css = """
-body {
-    background: #f0f2f5;
-    font-family: 'Segoe UI', sans-serif;
-    color: #333;
-}
-h1, h2, h3 {
-    font-weight: 600;
-}
-.gradio-container {
-    padding: 20px;
-}
-header {
-    text-align: center;
-    padding: 20px;
-    margin-bottom: 20px;
-}
-header h1 {
-    font-size: 3em;
-    color: #2c3e50;
-}
-.gr-button {
-    background-color: #3498db !important;
-    color: #fff !important;
-    border: none !important;
-    padding: 10px 20px !important;
-    border-radius: 5px !important;
-    font-size: 1em !important;
-}
-.gr-button:hover {
-    background-color: #2980b9 !important;
-}
-.gr-input, .gr-slider, .gr-number, .gr-textbox {
-    border-radius: 5px;
-}
-.gr-gallery-item {
-    border-radius: 10px;
-    overflow: hidden;
-    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
-}
-"""
-
-# Gradio Interface
-with gr.Blocks(css=custom_css, title="Multimodal & T2I") as demo:
-    with gr.Column(variant="panel"):
-        gr.Markdown("<header><h1>Chat With Janus-Pro-7B</h1></header>")
-
-        with gr.Tabs():
-            with gr.TabItem("Multimodal Understanding"):
-                gr.Markdown("### Chat with Images")
-                with gr.Row():
-                    image_input = gr.Image(label="Upload Image", type="numpy")
-                    with gr.Column():
-                        question_input = gr.Textbox(label="Question", placeholder="Enter your question about the image here...", lines=4)
-                        und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-                        top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
-                        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
-                understanding_button = gr.Button("Chat", elem_id="understanding-button")
-                understanding_output = gr.Textbox(label="Response", lines=6)
-                with gr.Accordion("Examples", open=False):
-                    gr.Examples(
-                        label="Multimodal Understanding Examples",
-                        examples=[
-                            ["explain this meme", "doge.png"]
-                        ],
-                        inputs=[question_input, image_input],
-                    )
-                understanding_button.click(
-                    multimodal_understanding,
-                    inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-                    outputs=understanding_output,
-                )
-
-            with gr.TabItem("Text-to-Image Generation"):
-                gr.Markdown("### Generate Images from Text")
-                with gr.Row():
-                    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter detailed prompt for image generation...", lines=4)
-                with gr.Row():
-                    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
-                    cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-                    t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
-                generation_button = gr.Button("Generate Images", elem_id="generation-button")
-                image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-                with gr.Accordion("Examples", open=False):
-                    gr.Examples(
-                        label="Text-to-Image Examples",
-                        examples=[
-                            "Master shifu racoon wearing drip attire as a street gangster.",
-                            "The face of a beautiful girl",
-                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-                            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
"๊ณ ์์ด๊ฐ ์ฐ์ฃผ๋ณต์ ์
๊ณ ๋ฌ์ ์๋ ๋ชจ์ต"
-                        ],
-                        inputs=prompt_input,
-                    )
-                generation_button.click(
-                    fn=generate_image,
-                    inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-                    outputs=image_output,
-                )
-
-    gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")

-demo.launch(
+# Load Janus Pro models for vision and text tasks
+vision_model = AutoModelForVision2Seq.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
+text_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
+processor = AutoProcessor.from_pretrained("deepseek-ai/janus-pro-7b")
+image_pipe = pipeline("text-to-image", model="deepseek-ai/janus-pro-7b")

+last_uploaded_image = None

+def detect_image_request(user_input):
+    image_keywords = ["generate an image", "create an image", "show me a picture", "draw", "visualize", "generate image",
+                      "image generation", "get me an image", "get an image", "need an image", "need image"]
+    return any(re.search(keyword, user_input, re.IGNORECASE) for keyword in image_keywords)

+def chatbot(history, image=None, user_input=""):
+    global last_uploaded_image

+    if image:
+        last_uploaded_image = image  # Store the latest uploaded image

+    if detect_image_request(user_input):
+        image = image_pipe(user_input)
+        history.append((user_input, "[Generated Image]"))
+        return history, "", image[0]["image"]

+    if last_uploaded_image:
+        inputs = processor(images=last_uploaded_image, return_tensors="pt").to("cuda")
+        output = vision_model.generate(**inputs)
+        response = processor.decode(output[0], skip_special_tokens=True)
+    else:
+        response = text_model.generate(user_input)
+        response = processor.decode(response[0], skip_special_tokens=True)

+    history.append((user_input, response))
+    return history, "", None
+
+def reset_chat():
+    global last_uploaded_image
+    last_uploaded_image = None
+    return [], ""
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Janus Pro Chatbot with Vision & Image Generation")
+    chatbot_interface = gr.Chatbot()
+    with gr.Row():
+        image_input = gr.Image(type="pil", label="Upload Image")
+        text_input = gr.Textbox(label="Type your message")
+    send_button = gr.Button("Send")
+    reset_button = gr.Button("Reset Chat")
+    image_output = gr.Image(label="Generated Image")

+    send_button.click(chatbot, [chatbot_interface, image_input, text_input], [chatbot_interface, text_input, image_output])
+    reset_button.click(reset_chat, [], [chatbot_interface, text_input])

+demo.launch()
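
For reference, the text-only branch of chatbot() above passes a raw string straight to text_model.generate(). A minimal sketch of how that branch might be wired through a tokenizer is shown below, assuming the loaded processor exposes a tokenizer attribute and the model follows the standard transformers generate() API; the helper name text_reply and the max_new_tokens value are illustrative and are not part of this commit.

import torch

def text_reply(user_input, text_model, processor, max_new_tokens=256):
    # Encode the prompt and move it to the model's device (assumes processor.tokenizer exists).
    inputs = processor.tokenizer(user_input, return_tensors="pt").to(text_model.device)
    with torch.no_grad():
        output_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the tokens generated after the prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(new_tokens, skip_special_tokens=True)

Under those assumptions, the else branch in chatbot() could call text_reply(user_input, text_model, processor) and append the returned string to history.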
|