rafaldembski commited on
Commit
27817fc
·
verified ·
1 Parent(s): d768756

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -7,12 +7,13 @@ import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
- import gdown
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
14
- os.system("git clone https://github.com/xuebinqin/DIS")
15
- os.system("mv DIS/IS-Net/* .")
 
 
16
 
17
  # project imports
18
  from data_loader_cache import normalize, im_reader, im_preprocess
@@ -21,9 +22,10 @@ from models import *
21
  # Helpers
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
 
24
- # Download official weights
25
  if not os.path.exists("saved_models"):
26
  os.mkdir("saved_models")
 
27
  os.system("mv isnet.pth saved_models/")
28
 
29
  class GOSNormalize(object):
@@ -193,22 +195,20 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as demo:
193
  article = gr.Markdown(translations["en"]["article"])
194
 
195
  with gr.Row():
196
- prompt = gr.Image(type='filepath')
197
- result = gr.Image(label="Segmented Image", show_label=False)
198
-
 
 
 
 
 
 
199
  language_selector.change(
200
  fn=change_language,
201
  inputs=language_selector,
202
  outputs=[title, description, article],
 
203
  )
204
 
205
- gr.Interface(
206
- fn=inference,
207
- inputs=prompt,
208
- outputs=[result],
209
- title=title,
210
- description=description,
211
- article=article,
212
- allow_flagging='never',
213
- cache_examples=False,
214
- ).launch()
 
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
 
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
+ # Sprawdzenie, czy katalog DIS istnieje przed klonowaniem
14
+ if not os.path.exists("DIS"):
15
+ os.system("git clone https://github.com/xuebinqin/DIS")
16
+ os.system("mv DIS/IS-Net/* .")
17
 
18
  # project imports
19
  from data_loader_cache import normalize, im_reader, im_preprocess
 
22
  # Helpers
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
+ # Sprawdzenie, czy katalog z modelami istnieje przed przeniesieniem
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
+ if not os.path.exists("saved_models/isnet.pth"):
29
  os.system("mv isnet.pth saved_models/")
30
 
31
  class GOSNormalize(object):
 
195
  article = gr.Markdown(translations["en"]["article"])
196
 
197
  with gr.Row():
198
+ with gr.Column(elem_id="col-container"):
199
+ gr.Image("logo.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False)
200
+ inputs = gr.Image(type='filepath', label="Wybierz obraz")
201
+ outputs = [gr.Image(label="Wynik (z przezroczystością)"), gr.Image(label="Maska")]
202
+
203
+ run_button = gr.Button("Segmentuj", scale=0)
204
+
205
+ run_button.click(fn=inference, inputs=inputs, outputs=outputs)
206
+
207
  language_selector.change(
208
  fn=change_language,
209
  inputs=language_selector,
210
  outputs=[title, description, article],
211
+ api_name=False,
212
  )
213
 
214
+ demo.launch()