OMilosh committed
Commit 7887362 · verified · Parent: f7adb40

Update app.py

Files changed (1): app.py (+21 -8)
app.py CHANGED
```diff
@@ -6,23 +6,28 @@ import random
 from diffusers import DiffusionPipeline
 import torch
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+available_models = [
+    "stabilityai/sdxl-turbo",
+    "CompVis/stable-diffusion-v1-4",
+    "runwayml/stable-diffusion-v1-5",
+]
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
+def init_model(model_repo_id):
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+    return pipe.to(device)
 
 # @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
+    model_repo_id,
     prompt,
     negative_prompt,
     seed,
@@ -33,6 +38,8 @@ def infer(
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
+    pipe = init_model(model_repo_id)
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
@@ -82,6 +89,11 @@ with gr.Blocks(css=css) as demo:
         result = gr.Image(label="Result", show_label=False)
 
         with gr.Accordion("Advanced Settings", open=False):
+            model_repo_id = gr.Dropdown(available_models,
+                                        multiselect=False,
+                                        label="Model",
+                                        info="Choose models for generation")
+
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
@@ -138,6 +150,7 @@ with gr.Blocks(css=css) as demo:
         triggers=[run_button.click, prompt.submit],
         fn=infer,
         inputs=[
+            model_repo_id,
             prompt,
             negative_prompt,
             seed,
```
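In short, this commit swaps the hard-coded module-level pipeline for per-request construction: `infer` now takes the repo id chosen in the new Model dropdown and builds the pipeline through `init_model`. As written, `init_model` rebuilds the pipeline from the cached weights on every generation; below is a minimal sketch of one way to keep model switching cheap (the `lru_cache` wrapper is an assumption layered on top of the commit, not part of it):

```python
# Sketch only: cache one pipeline per repo id so switching models in the
# dropdown does not rebuild the pipeline on every infer() call. The
# lru_cache wrapper is an assumption, not something this commit contains.
from functools import lru_cache

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

@lru_cache(maxsize=2)  # keep at most two pipelines resident in memory
def init_model(model_repo_id: str):
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    return pipe.to(device)
```

With such a cache, repeated generations with the same dropdown choice reuse the resident pipeline, and the call site `pipe = init_model(model_repo_id)` inside `infer` stays unchanged.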
 
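On the UI side, the dropdown component is listed first in `inputs`, matching the new leading `model_repo_id` parameter of `infer`. Below is a self-contained sketch of that wiring with a stand-in handler in place of the real `infer`; only the `gr.Dropdown` arguments and the `gr.on` pattern come from the diff, the rest is assumption:

```python
# Self-contained sketch of the dropdown-to-handler wiring from this commit.
# describe() is a stand-in for the real infer().
import gradio as gr

available_models = [
    "stabilityai/sdxl-turbo",
    "CompVis/stable-diffusion-v1-4",
    "runwayml/stable-diffusion-v1-5",
]

def describe(model_repo_id, prompt):
    # Stand-in for infer(): report which model would handle the prompt.
    return f"Would run {model_repo_id!r} on prompt {prompt!r}"

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt")
    with gr.Accordion("Advanced Settings", open=False):
        # Same Dropdown arguments as in the commit.
        model_repo_id = gr.Dropdown(available_models,
                                    multiselect=False,
                                    label="Model",
                                    info="Choose models for generation")
    run_button = gr.Button("Run")
    result = gr.Text(label="Result", show_label=False)

    # As in the diff: the dropdown is first in `inputs`, so its value arrives
    # as the first positional argument of the handler.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=describe,
        inputs=[model_repo_id, prompt],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()
```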