cwhuh committed on
Commit
586b09a
·
1 Parent(s): ec3cccf

add : refinement logic

Browse files
Files changed (3) hide show
  1. app.py +18 -5
  2. llm_wrapper.py +107 -0
  3. prompt.json +7 -0
app.py CHANGED
@@ -7,12 +7,19 @@ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, Autoe
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
 
 
10
  from huggingface_hub import hf_hub_download
11
  from safetensors.torch import load_file
12
  import subprocess
13
 
14
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
15
 
 
 
 
 
 
 
16
  dtype = torch.bfloat16
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
@@ -38,9 +45,15 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
38
  if randomize_seed:
39
  seed = random.randint(0, MAX_SEED)
40
  generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
41
 
42
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
43
- prompt=prompt,
44
  guidance_scale=guidance_scale,
45
  num_inference_steps=num_inference_steps,
46
  width=width,
@@ -52,9 +65,9 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
52
  yield img, seed
53
 
54
  examples = [
55
- "๋กœ์ผ“์— ํƒ€๊ณ  ์žˆ๋Š” ํฌ๋‹‰์Šค",
56
- "์ผ๋ ‰๊ธฐํƒ€๋ฅผ ๋“ค๊ณ  ์žˆ๋Š” ํฌ๋‹‰์Šค",
57
- "์ปดํ“จํ„ฐ๊ณตํ•™์„ ๊ณต๋ถ€์ค‘์ธ ํฌ๋‹‰์Šค",
58
  ]
59
 
60
  css="""
@@ -67,7 +80,7 @@ css="""
67
  with gr.Blocks(css=css) as demo:
68
 
69
  with gr.Column(elem_id="col-container"):
70
- gr.Markdown(f"""# [POSTECH] PONIX Generator
71
  [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
72
  """)
73
 
 
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
 
10
+ from llm_wrapper import run_gemini
11
  from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file
13
  import subprocess
14
 
15
  subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
16
 
17
+
18
+ from pydantic import BaseModel
19
+
20
+ class RefinedPrompt(BaseModel):
21
+ prompt: str
22
+
23
  dtype = torch.bfloat16
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
 
45
  if randomize_seed:
46
  seed = random.randint(0, MAX_SEED)
47
  generator = torch.Generator().manual_seed(seed)
48
+
49
+ refined_prompt = run_gemini(
50
+ target_prompt=prompt,
51
+ prompt_in_path="prompt.json",
52
+ output_structure=RefinedPrompt,
53
+ )
54
 
55
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
56
+ prompt=refined_prompt.prompt,
57
  guidance_scale=guidance_scale,
58
  num_inference_steps=num_inference_steps,
59
  width=width,
 
65
  yield img, seed
66
 
67
  examples = [
68
+ "๊ธฐ๊ณ„๊ณตํ•™๊ณผ(๋กœ์ผ“) ํฌ๋‹‰์Šค",
69
+ "๋ฐ”์ด์˜ฌ๋ฆฐ์„ ์—ฐ์ฃผํ•˜๋Š” ํฌ๋‹‰์Šค",
70
+ "๋ฌผ๋ฆฌํ•™์„ ์—ฐ๊ตฌํ•˜๋Š” ํฌ๋‹‰์Šค",
71
  ]
72
 
73
  css="""
 
80
  with gr.Blocks(css=css) as demo:
81
 
82
  with gr.Column(elem_id="col-container"):
83
+ gr.Markdown(f"""# [POSTECH] PONIX Generator ๐ŸŒŠ
84
  [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
85
  """)
86
 
llm_wrapper.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import requests, os, json, time
5
+
6
+ from google import genai
7
+
8
+ prompt_base_path = "src/llm_wrapper/prompt"
9
+ client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
+
11
+
12
def encode_image(image_source, timeout: float = 30.0):
    """Convert an image into a JPEG-encoded ``genai.types.Part``.

    URL strings, local files and in-memory Pillow images are all handled the
    same way: the image is opened, converted to a JPEG-compatible mode if
    needed, re-encoded as JPEG and wrapped in a ``Part``.

    Args:
        image_source: A Pillow ``Image.Image`` (used as-is), a string starting
            with ``http://`` or ``https://`` (downloaded with ``requests``),
            or anything else ``Image.open`` accepts, e.g. a local file path.
        timeout: Seconds to wait for the HTTP download. Only used for URL
            sources.

    Returns:
        The value of ``genai.types.Part.from_bytes`` built from the JPEG bytes
        with mime type ``image/jpeg``.

    Raises:
        ValueError: If the download fails (including HTTP error statuses), if
            the image cannot be opened or saved, or if its format is not one
            of Pillow's registered formats.
    """
    # Modes passed to the JPEG encoder unchanged; every other mode
    # (RGBA, P, CMYK, LA, I, F, ...) is converted to RGB first.
    jpeg_ready_modes = ("1", "L", "RGB", "RGBX", "YCbCr")

    try:
        if isinstance(image_source, Image.Image):
            # Already a Pillow image object: use it directly.
            image = image_source
        elif isinstance(image_source, str) and image_source.startswith(
            ("http://", "https://")
        ):
            # Remote image: download it, failing loudly on HTTP errors so an
            # error page is never handed to Pillow as if it were image data.
            response = requests.get(image_source, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Anything else is handed to Pillow (e.g. a local file path).
            image = Image.open(image_source)

        # Images created in memory have no format; treat them as JPEG.
        image_format = "JPEG" if image.format is None else image.format

        # Reject formats Pillow has no registered extension for.
        if image_format not in Image.registered_extensions().values():
            raise ValueError(f"Unsupported image format: {image_format}.")

        if image.mode not in jpeg_ready_modes:
            image = image.convert("RGB")

        buffered = BytesIO()
        image.save(buffered, format="JPEG")

        return genai.types.Part.from_bytes(
            data=buffered.getvalue(), mime_type="image/jpeg"
        )

    # RequestException derives from IOError, so it must be handled first.
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download the image from URL: {e}") from e
    except IOError as e:
        raise ValueError(f"Failed to process the image file: {e}") from e
58
+
59
+
60
def run_gemini(
    target_prompt: str,
    prompt_in_path: str,
    output_structure,
    img_in_data=None,
    model: str = "gemini-2.0-flash",
) -> tuple:
    """Call the Gemini API synchronously and return the structured response.

    The prompt template is a JSON file with a ``system_prompt`` string and a
    ``user_prompt`` object holding ``head`` and ``tail`` strings. The user
    message is ``head``, ``target_prompt`` and ``tail`` joined by newlines.
    No retry is attempted.

    Args:
        target_prompt: Text inserted between the template's head and tail.
        prompt_in_path: Prompt template file. It is looked up under
            ``prompt_base_path`` first; if no file exists there, the path is
            used as given.
        output_structure: Schema passed to the API as ``response_schema``.
        img_in_data: Optional image forwarded to ``encode_image`` and appended
            to the request contents.
        model: Name of the Gemini model to call.

    Returns:
        A 2-tuple ``(parsed_output, raw_response)``: the response's ``parsed``
        attribute and the full response object. Callers that only need the
        parsed result must unpack the tuple.

    Raises:
        FileNotFoundError: If the prompt template cannot be found.
        KeyError: If the template lacks one of the expected keys.
    """
    prompt_path = os.path.join(prompt_base_path, prompt_in_path)
    # Fall back to the path as given when the template is not under the base
    # directory (e.g. a prompt file kept next to the application).
    if not os.path.isfile(prompt_path) and os.path.isfile(prompt_in_path):
        prompt_path = prompt_in_path

    with open(prompt_path, "r", encoding="utf-8") as file:
        prompt_dict = json.load(file)

    system_prompt = prompt_dict["system_prompt"]
    user_prompt_head = prompt_dict["user_prompt"]["head"]
    user_prompt_tail = prompt_dict["user_prompt"]["tail"]

    user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
    input_content = [user_prompt_text]

    if img_in_data is not None:
        input_content.append(encode_image(img_in_data))

    logging.info("Requested API for chat completion response (sync call)...")
    start_time = time.time()

    # Synchronous call; the response is requested as JSON matching the schema.
    chat_completion = client.models.generate_content(
        model=model,
        contents=input_content,
        config={
            "system_instruction": system_prompt,
            "response_mime_type": "application/json",
            "response_schema": output_structure,
        },
    )

    chat_output = chat_completion.parsed

    # Token counts may be missing from the response; count them as zero so the
    # cost estimate never breaks a successful request.
    usage = chat_completion.usage_metadata
    input_token = (usage.prompt_token_count or 0) if usage is not None else 0
    output_token = (usage.candidates_token_count or 0) if usage is not None else 0

    # NOTE(review): 0.1 / 0.7 look like per-million-token prices and 1500 like
    # a currency conversion rate to KRW -- confirm against current pricing.
    pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500

    logging.info(
        "[GEMINI] Request completed (sync). Time taken: %.2fs / Pricing(KRW): %.2f",
        time.time() - start_time,
        pricing,
    )
    return chat_output, chat_completion
prompt.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "system_prompt": "**์—ญํ• (Role)**\n\n๋‹น์‹ (์‹œ์Šคํ…œ)์€ ์‚ฌ์šฉ์ž๊ฐ€ ์ž…๋ ฅํ•œ ์š”๊ตฌ์‚ฌํ•ญ(ํ•œ๊ตญ์–ด ๋ฌธ์žฅ)์„ ๋ฐ›์•„,\nTextual Inversion์ด ์ ์šฉ๋œ ํ† ํฐ(\"<s0><s1><s2> plush bird\")์ด ๋ฐ˜๋“œ์‹œ ํฌํ•จ๋œ ์˜์–ด ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.\n์ถ”๊ฐ€๋กœ ๋ฐฐ๊ฒฝยท์ƒํ™ฉยท์Šคํƒ€์ผ์„ ํ’๋ถ€ํ•˜๊ฒŒ ๋ฌ˜์‚ฌํ•˜์—ฌ, ์‹ค์ œ ํ…์ŠคํŠธ ํˆฌ ์ด๋ฏธ์ง€ ๋ชจ๋ธ์— ๋„ฃ๊ธฐ๋งŒ ํ•˜๋ฉด ์›ํ•˜๋Š” ์žฅ๋ฉด์ด ์ƒ์„ฑ๋˜๋„๋ก ๋„์›€์„ ์ค๋‹ˆ๋‹ค.\n\n\n**์ฃผ์š” ๊ทœ์น™(Rules)**\n\n**\"ํฌ๋‹‰์Šค\"**๋ผ๋Š” ๋‹จ์–ด๊ฐ€ ์š”๊ตฌ์‚ฌํ•ญ์— ๋“ฑ์žฅํ•˜๋ฉด, ์ด๋ฅผ **<s0><s1><s2> plush bird**๋กœ ์น˜ํ™˜ํ•œ๋‹ค.\nํ•ญ์ƒ ํ”„๋กฌํ”„ํŠธ ๋ฌธ์žฅ ์•ž๋ถ€๋ถ„์— **photo of <s0><s1><s2> plush bird**๋ฅผ ํฌํ•จํ•œ๋‹ค.\n์‚ฌ์šฉ์ž๊ฐ€ ์›ํ•˜๋Š” **์žฅ๋ฉด(๋ฐฐ๊ฒฝ, ํ™˜๊ฒฝ, ์ƒํ™ฉ)**์„ ์˜์–ด๋กœ ์ž์„ธํžˆ ๋ฌ˜์‚ฌํ•œ๋‹ค.\n์Šคํƒ€์ผ(e.g., hyper-realistic, cinematic lighting, 8k resolution, ultra high quality, ๋“ฑ)์„ ์ ์ ˆํžˆ ์ถ”๊ฐ€ํ•ด ๊ณ ํ€„๋ฆฌํ‹ฐ ์ด๋ฏธ์ง€๋ฅผ ์œ ๋„ํ•œ๋‹ค.\n์ถœ๋ ฅ์€ ๋‹จ์ผ ๋ฌธ์ž์—ด(๋˜๋Š” ์—ฌ๋Ÿฌ ์ค„) ํ˜•ํƒœ๋กœ ์˜์–ด ๋ฌธ์žฅ ์œ„์ฃผ๋กœ ์ž‘์„ฑํ•œ๋‹ค.\n\n\n\n**์˜ˆ์‹œ(Examples)**\n\nInput 1\n์‚ฌ์šฉ์ž: \"๊ธฐ๊ณ„๊ณตํ•™๊ณผ(๋กœ์ผ“) ํฌ๋‹‰์Šค\"\n์‹œ์Šคํ…œ ๋ณ€ํ™˜:\nphoto of <s0><s1><s2> plush bird \nwearing an astronaut suit and space helmet\ninside a spacecraft cockpit during flight, \nsurrounded by control panels and navigation systems,\nblinking lights and monitoring screens,\nEarth visible through the spacecraft window in background,\nhyper-realistic details, cinematic lighting, 8k resolution, \nultra high quality photograph, \nhigh-tech space environment, adventurous atmosphere\n\n\nInput 2\n์‚ฌ์šฉ์ž: \"๋ฐ”์ด์˜ฌ๋ฆฐ์„ ์—ฐ์ฃผํ•˜๋Š” ํฌ๋‹‰์Šค\"\n์‹œ์Šคํ…œ ๋ณ€ํ™˜:\nphoto of <s0><s1><s2> plush bird \nwearing an elegant black tailcoat,\ncrisp white dress shirt with bow tie,\nformal concert attire,\nholding a violin in playing position,\nin a grand concert hall with ornate architecture,\nwarm ambient lighting from chandeliers,\norchestra members visible in background,\nsheet music on stand 
nearby,\naudience in formal attire visible,\npolished wooden stage floor,\nconductor's podium visible in background,\nhyper-realistic details, warm classical lighting, 8k resolution,\nultra high quality photograph,\nprofessional classical concert environment, performance moment\n\n\nInput 3\n์‚ฌ์šฉ์ž: \"๋ฌผ๋ฆฌํ•™์„ ์—ฐ๊ตฌํ•˜๋Š” ํฌ๋‹‰์Šค\"\n์‹œ์Šคํ…œ ๋ณ€ํ™˜:\nphoto of <s0><s1><s2> plush bird \nwearing a lab coat and safety glasses,\ninside a physics laboratory,\nconducting experiments with quantum physics equipment,\nsurrounded by equations written on whiteboards,\noperating particle accelerator models,\nmeasuring devices and scientific instruments visible,\nhyper-realistic details, dramatic academic lighting, 8k resolution,\nultra high quality photograph,\nscientific environment, discovery atmosphere\n\n\n**์ถœ๋ ฅ ํ˜•์‹(Output Format)**\n\n์ตœ์ข… ์ถœ๋ ฅ์€ ์˜์–ด ํ…์ŠคํŠธ๋กœ ๋œ ํ•˜๋‚˜์˜ ํ”„๋กฌํ”„ํŠธ ๋ฌธ์žฅ(๋˜๋Š” ์—ฌ๋Ÿฌ ์ค„)์ด๋ฉฐ,\n๋ฐ˜๋“œ์‹œ <s0><s1><s2> plush bird๊ฐ€ ๋“ค์–ด ์žˆ์–ด์•ผ ํ•จ.\n์ƒํ™ฉ์— ๋”ฐ๋ผ ๋ฐฐ๊ฒฝยท๋””ํ…Œ์ผยท์กฐ๋ช…ยทํ•ด์ƒ๋„๋ฅผ ๋‹ค์–‘ํ•œ ํ˜•์šฉ์‚ฌ๋กœ ํ’๋ถ€ํžˆ ๊ธฐ์ˆ ํ•ด ์ค€๋‹ค.",
3
+ "user_prompt": {
4
+ "head": "",
5
+ "tail": ""
6
+ }
7
+ }