jamino30 commited on
Commit
d35a711
·
verified ·
1 Parent(s): fbacbe7

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +33 -64
  2. requirements.txt +0 -1
app.py CHANGED
@@ -9,7 +9,6 @@ import torch.optim as optim
9
  import torchvision.models as models
10
  import numpy as np
11
  import gradio as gr
12
- from gradio_imageslider import ImageSlider
13
  from safetensors.torch import load_file
14
  from huggingface_hub import hf_hub_download
15
 
@@ -54,66 +53,38 @@ for style_name, style_img_path in style_options.items():
54
  cached_style_features[style_name] = style_features
55
 
56
  @spaces.GPU(duration=30)
57
- def run(content_image, style_name, style_strength=10, optim_name='AdamW'):
58
- yield [None] * 3
59
  content_img, original_size = preprocess_img(content_image, img_size)
60
  content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
61
  content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
 
62
 
63
- if optim_name == 'Adam':
64
- optim_caller = torch.optim.Adam
65
- elif optim_name == 'AdamW':
66
- optim_caller = torch.optim.AdamW
67
- else:
68
- optim_caller = torch.optim.LBFGS
69
 
70
  print('-'*15)
71
  print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
72
  print('STYLE:', style_name)
73
  print('CONTENT IMG SIZE:', original_size)
74
  print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})')
75
-
76
- style_features = cached_style_features[style_name]
77
 
78
  st = time.time()
79
-
80
- if device == 'cuda':
81
- stream_all = torch.cuda.Stream()
82
- stream_bg = torch.cuda.Stream()
83
-
84
- def run_inference_cuda(apply_to_background, stream):
85
- with torch.cuda.stream(stream):
86
- return run_inference(apply_to_background)
87
-
88
- def run_inference(apply_to_background):
89
- return inference(
90
- model=model,
91
- sod_model=sod_model,
92
- content_image=content_img,
93
- content_image_norm=content_img_normalized,
94
- style_features=style_features,
95
- lr=lrs[style_strength-1],
96
- apply_to_background=apply_to_background,
97
- optim_caller=optim_caller
98
- )
99
-
100
- with ThreadPoolExecutor() as executor:
101
- if device == 'cuda':
102
- future_all = executor.submit(run_inference_cuda, False, stream_all)
103
- future_bg = executor.submit(run_inference_cuda, True, stream_bg)
104
- else:
105
- future_all = executor.submit(run_inference, False)
106
- future_bg = executor.submit(run_inference, True)
107
- generated_img_all = future_all.result()
108
- generated_img_bg = future_bg.result()
109
-
110
  et = time.time()
111
  print('TIME TAKEN:', et-st)
112
 
113
- yield (
114
- (content_image, postprocess_img(generated_img_all, original_size)),
115
- (content_image, postprocess_img(generated_img_bg, original_size))
116
- )
117
 
118
  def set_slider(value):
119
  return gr.update(value=value)
@@ -133,6 +104,8 @@ with gr.Blocks(css=css) as demo:
133
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
134
  with gr.Group():
135
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
 
 
136
  with gr.Accordion(label='Advanced Options', open=False):
137
  optim_dropdown = gr.Radio(choices=['Adam', 'AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
138
  submit_button = gr.Button('Submit', variant='primary')
@@ -147,34 +120,30 @@ with gr.Blocks(css=css) as demo:
147
  )
148
 
149
  with gr.Column():
150
- output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
151
- download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
152
- with gr.Group():
153
- output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
154
- download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
155
 
156
- def save_image(img_tuple1, img_tuple2):
157
- filename1, filename2 = 'generated-all.jpg', 'generated-bg.jpg'
158
- img_tuple1[1].save(filename1)
159
- img_tuple2[1].save(filename2)
160
- return filename1, filename2
161
 
162
  submit_button.click(
163
- fn=lambda: [gr.update(visible=False) for _ in range(2)],
164
- outputs=[download_button_1, download_button_2]
165
  )
166
 
167
  submit_button.click(
168
  fn=run,
169
- inputs=[content_image, style_dropdown, style_strength_slider, optim_dropdown],
170
- outputs=[output_image_all, output_image_background]
171
  ).then(
172
  fn=save_image,
173
- inputs=[output_image_all, output_image_background],
174
- outputs=[download_button_1, download_button_2]
175
  ).then(
176
- fn=lambda: [gr.update(visible=True) for _ in range(2)],
177
- outputs=[download_button_1, download_button_2]
178
  )
179
 
180
  demo.queue = False
 
9
  import torchvision.models as models
10
  import numpy as np
11
  import gradio as gr
 
12
  from safetensors.torch import load_file
13
  from huggingface_hub import hf_hub_download
14
 
 
53
  cached_style_features[style_name] = style_features
54
 
55
  @spaces.GPU(duration=30)
56
+ def run(content_image, style_name, style_strength=10, optim_name='AdamW', apply_to_background=False):
57
+ yield None
58
  content_img, original_size = preprocess_img(content_image, img_size)
59
  content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
60
  content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
61
+ style_features = cached_style_features[style_name]
62
 
63
+ if optim_name == 'Adam': optim_caller = torch.optim.Adam
64
+ elif optim_name == 'AdamW': optim_caller = torch.optim.AdamW
65
+ elif optim_name == 'L-BFGS': optim_caller = torch.optim.LBFGS
 
 
 
66
 
67
  print('-'*15)
68
  print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
69
  print('STYLE:', style_name)
70
  print('CONTENT IMG SIZE:', original_size)
71
  print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})')
 
 
72
 
73
  st = time.time()
74
+ generated_img = inference(
75
+ model=model,
76
+ sod_model=sod_model,
77
+ content_image=content_img,
78
+ content_image_norm=content_img_normalized,
79
+ style_features=style_features,
80
+ lr=lrs[style_strength-1],
81
+ apply_to_background=apply_to_background,
82
+ optim_caller=optim_caller
83
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  et = time.time()
85
  print('TIME TAKEN:', et-st)
86
 
87
+ yield postprocess_img(generated_img, original_size)
 
 
 
88
 
89
  def set_slider(value):
90
  return gr.update(value=value)
 
104
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
105
  with gr.Group():
106
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
107
+ with gr.Group():
108
+ apply_to_background_checkbox = gr.Checkbox(label='Apply styling to background only', value=False)
109
  with gr.Accordion(label='Advanced Options', open=False):
110
  optim_dropdown = gr.Radio(choices=['Adam', 'AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
111
  submit_button = gr.Button('Submit', variant='primary')
 
120
  )
121
 
122
  with gr.Column():
123
+ output_image = gr.Image(label='Output', type='pil', interactive=False, show_download_button=False)
124
+ download_button = gr.DownloadButton(label='Download Image', visible=False)
 
 
 
125
 
126
+ def save_image(img):
127
+ filename = 'generated.jpg'
128
+ img.save(filename)
129
+ return filename
 
130
 
131
  submit_button.click(
132
+ fn=lambda: gr.update(visible=False),
133
+ outputs=download_button
134
  )
135
 
136
  submit_button.click(
137
  fn=run,
138
+ inputs=[content_image, style_dropdown, style_strength_slider, optim_dropdown, apply_to_background_checkbox],
139
+ outputs=output_image
140
  ).then(
141
  fn=save_image,
142
+ inputs=output_image,
143
+ outputs=download_button
144
  ).then(
145
+ fn=lambda: gr.update(visible=True),
146
+ outputs=download_button
147
  )
148
 
149
  demo.queue = False
requirements.txt CHANGED
@@ -5,7 +5,6 @@ safetensors
5
  huggingface_hub
6
  pillow
7
  gradio
8
- gradio_imageslider
9
  spaces
10
  tqdm
11
  tensorboard
 
5
  huggingface_hub
6
  pillow
7
  gradio
 
8
  spaces
9
  tqdm
10
  tensorboard