Reality123b commited on
Commit
27f5740
·
verified ·
1 Parent(s): b414b9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -179
app.py CHANGED
@@ -1,176 +1,159 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- from PIL import Image
4
- import time
5
- import os
6
- import base64
7
- from io import BytesIO
8
-
9
- HF_TOKEN = os.environ.get("HF_TOKEN")
10
-
11
- if not HF_TOKEN:
12
- HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
13
- else:
14
- HF_TOKEN_ERROR = None
15
-
16
- client = InferenceClient(token=HF_TOKEN)
17
- PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
18
-
19
- def improve_prompt(original_prompt):
20
- if HF_TOKEN_ERROR:
21
- raise gr.Error(HF_TOKEN_ERROR)
22
-
23
- try:
24
- system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
25
- prompt_for_llm = f"""<|system|>
26
- {system_prompt}</s>
27
- <|user|>
28
- Improve this prompt: {original_prompt}
29
- </s>
30
- <|assistant|>
31
- """
32
- improved_prompt = client.text_generation(
33
- prompt=prompt_for_llm,
34
- model=PROMPT_IMPROVER_MODEL,
35
- max_new_tokens=128,
36
- temperature=0.7,
37
- top_p=0.9,
38
- repetition_penalty=1.2,
39
- stop_sequences=["</s>"],
40
-
41
- )
42
-
43
- return improved_prompt.strip()
44
-
45
- except Exception as e:
46
- print(f"Error improving prompt: {e}")
47
- return original_prompt
48
-
49
-
50
- def generate_image(prompt, progress=gr.Progress()):
51
- if HF_TOKEN_ERROR:
52
- raise gr.Error(HF_TOKEN_ERROR)
53
-
54
- progress(0, desc="Improving prompt...")
55
- improved_prompt = improve_prompt(prompt)
56
-
57
- progress(0.2, desc="Sending request to Hugging Face...")
58
- try:
59
- image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
60
-
61
- if not isinstance(image, Image.Image):
62
- raise Exception(f"Expected a PIL Image, but got: {type(image)}")
63
-
64
- progress(0.8, desc="Processing image...")
65
- time.sleep(0.5)
66
- progress(1.0, desc="Done!")
67
- return image, improved_prompt
68
- except Exception as e:
69
- if "rate limit" in str(e).lower():
70
- error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
71
- else:
72
- error_message = f"An error occurred: {e}"
73
- raise gr.Error(error_message)
74
-
75
- def pil_to_base64(img):
76
- buffered = BytesIO()
77
- img.save(buffered, format="PNG")
78
- img_str = base64.b64encode(buffered.getvalue()).decode()
79
- return f"data:image/png;base64,{img_str}"
80
-
81
- css = """
82
- .container {
83
- max-width: 800px;
84
- margin: auto;
85
- padding: 20px;
86
- border: 1px solid #ddd;
87
- border-radius: 10px;
88
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
89
- }
90
- .title {
91
- text-align: center;
92
- font-size: 2.5em;
93
- margin-bottom: 0.5em;
94
- color: #333;
95
- font-family: 'Arial', sans-serif;
96
- }
97
- .description {
98
- text-align: center;
99
- font-size: 1.1em;
100
- margin-bottom: 1.5em;
101
- color: #555;
102
- }
103
- .input-section, .output-section {
104
- margin-bottom: 1.5em;
105
- }
106
- .output-section img {
107
- display: block;
108
- margin: auto;
109
- max-width: 100%;
110
- height: auto;
111
- border-radius: 8px;
112
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
113
- }
114
-
115
- @keyframes fadeIn {
116
- from { opacity: 0; transform: translateY(20px); }
117
- to { opacity: 1; transform: translateY(0); }
118
- }
119
- .output-section.animate img {
120
- animation: fadeIn 0.8s ease-out;
121
- }
122
-
123
- .submit-button {
124
- display: block;
125
- margin: auto;
126
- padding: 10px 20px;
127
- font-size: 1.1em;
128
- color: white;
129
- background-color: #4CAF50;
130
- border: none;
131
- border-radius: 5px;
132
- cursor: pointer;
133
- transition: background-color 0.3s ease;
134
- }
135
- .submit-button:hover {
136
- background-color: #367c39;
137
- }
138
-
139
- .error-message {
140
- color: red;
141
- text-align: center;
142
- margin-top: 1em;
143
- font-weight: bold;
144
- }
145
- label{
146
- font-weight: bold;
147
- display: block;
148
- margin-bottom: 0.5em;
149
- }
150
-
151
- .improved-prompt-display {
152
- margin-top: 10px;
153
- padding: 8px;
154
- border: 1px solid #ccc;
155
- border-radius: 4px;
156
- background-color: #f9f9f9;
157
- font-style: italic;
158
- color: #444;
159
- }
160
- .download-link {
161
- display: block;
162
- text-align: center;
163
- margin-top: 10px;
164
- color: #4CAF50;
165
- text-decoration: none;
166
- font-weight: bold;
167
- }
168
-
169
- .download-link:hover{
170
- text-decoration: underline;
171
- }
172
- """
173
-
174
 
175
  with gr.Blocks(css=css) as demo:
176
  gr.Markdown(
@@ -191,15 +174,13 @@ with gr.Blocks(css=css) as demo:
191
 
192
  def on_generate_click(prompt):
193
  output_group.elem_classes = ["output-section", "animate"]
194
- image, _ = generate_image(prompt) # Ignore the improved prompt
195
  output_group.elem_classes = ["output-section"]
196
- image_b64 = pil_to_base64(image)
197
- download_html = f'<a class="download-link" href="{image_b64}" download="generated_image.png">Download Image</a>'
198
 
199
- return image, download_html # No improved prompt to return
200
 
201
- generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, download_link])
202
- prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, download_link])
203
 
204
  gr.Examples(
205
  [["A dog"],
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ from PIL import Image
4
+ import time
5
+ import os
6
+ import base64
7
+ from io import BytesIO
8
+
9
+ HF_TOKEN = os.environ.get("HF_TOKEN")
10
+
11
+ if not HF_TOKEN:
12
+ HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
13
+ else:
14
+ HF_TOKEN_ERROR = None
15
+
16
+ client = InferenceClient(token=HF_TOKEN)
17
+ PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
18
+
19
+ def improve_prompt(original_prompt):
20
+ if HF_TOKEN_ERROR:
21
+ raise gr.Error(HF_TOKEN_ERROR)
22
+
23
+ try:
24
+ system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
25
+ prompt_for_llm = f"""<|system|>
26
+ {system_prompt}</s>
27
+ <|user|>
28
+ Improve this prompt: {original_prompt}
29
+ </s>
30
+ <|assistant|>
31
+ """
32
+ improved_prompt = client.text_generation(
33
+ prompt=prompt_for_llm,
34
+ model=PROMPT_IMPROVER_MODEL,
35
+ max_new_tokens=128,
36
+ temperature=0.7,
37
+ top_p=0.9,
38
+ repetition_penalty=1.2,
39
+ stop_sequences=["</s>"],
40
+ )
41
+
42
+ return improved_prompt.strip()
43
+
44
+ except Exception as e:
45
+ print(f"Error improving prompt: {e}")
46
+ return original_prompt
47
+
48
+
49
+ def generate_image(prompt, progress=gr.Progress()):
50
+ if HF_TOKEN_ERROR:
51
+ raise gr.Error(HF_TOKEN_ERROR)
52
+
53
+ progress(0, desc="Improving prompt...")
54
+ improved_prompt = improve_prompt(prompt)
55
+
56
+ progress(0.2, desc="Sending request to Hugging Face...")
57
+ try:
58
+ image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
59
+
60
+ if not isinstance(image, Image.Image):
61
+ raise Exception(f"Expected a PIL Image, but got: {type(image)}")
62
+
63
+ progress(0.8, desc="Processing image...")
64
+ time.sleep(0.5)
65
+ progress(1.0, desc="Done!")
66
+ return image
67
+ except Exception as e:
68
+ if "rate limit" in str(e).lower():
69
+ error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
70
+ else:
71
+ error_message = f"An error occurred: {e}"
72
+ raise gr.Error(error_message)
73
+
74
+
75
+ def pil_to_base64(img):
76
+ buffered = BytesIO()
77
+ img.save(buffered, format="PNG")
78
+ img_str = base64.b64encode(buffered.getvalue()).decode()
79
+ return f"data:image/png;base64,{img_str}"
80
+
81
+
82
+ css = """
83
+ body {
84
+ background-color: #f4f4f4;
85
+ font-family: 'Arial', sans-serif;
86
+ }
87
+
88
+ .container {
89
+ max-width: 900px;
90
+ margin: auto;
91
+ padding: 30px;
92
+ border-radius: 10px;
93
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
94
+ background-color: white;
95
+ }
96
+
97
+ .title {
98
+ text-align: center;
99
+ font-size: 3em;
100
+ margin-bottom: 0.5em;
101
+ color: #3a3a3a;
102
+ }
103
+
104
+ .input-section {
105
+ background-color: #e3f7fc;
106
+ border-radius: 8px;
107
+ padding: 15px;
108
+ }
109
+
110
+ .output-section {
111
+ background-color: #f0f0f0;
112
+ border-radius: 8px;
113
+ padding: 15px;
114
+ }
115
+
116
+ .output-section img {
117
+ max-width: 100%;
118
+ height: auto;
119
+ border-radius: 8px;
120
+ }
121
+
122
+ .submit-button {
123
+ background-color: #007BFF;
124
+ border: none;
125
+ border-radius: 5px;
126
+ color: white;
127
+ padding: 12px 20px;
128
+ cursor: pointer;
129
+ transition: background-color 0.3s ease, transform 0.2s ease;
130
+ }
131
+
132
+ .submit-button:hover {
133
+ background-color: #0056b3;
134
+ transform: scale(1.05);
135
+ }
136
+
137
+ .error-message {
138
+ color: red;
139
+ text-align: center;
140
+ font-weight: bold;
141
+ }
142
+
143
+ .label {
144
+ font-weight: bold;
145
+ }
146
+
147
+ .download-link {
148
+ color: #007BFF;
149
+ font-weight: bold;
150
+ text-decoration: none;
151
+ }
152
+
153
+ .download-link:hover {
154
+ text-decoration: underline;
155
+ }
156
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  with gr.Blocks(css=css) as demo:
159
  gr.Markdown(
 
174
 
175
  def on_generate_click(prompt):
176
  output_group.elem_classes = ["output-section", "animate"]
177
+ image = generate_image(prompt) # Ignore the improved prompt
178
  output_group.elem_classes = ["output-section"]
 
 
179
 
180
+ return image # Return only the generated image
181
 
182
+ generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
183
+ prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)
184
 
185
  gr.Examples(
186
  [["A dog"],