hysts HF Staff commited on
Commit
a84e7ee
·
1 Parent(s): f8fdc71
Files changed (5) hide show
  1. .pre-commit-config.yaml +59 -34
  2. README.md +1 -1
  3. app.py +37 -59
  4. requirements.txt +2 -2
  5. style.css +4 -0
.pre-commit-config.yaml CHANGED
@@ -1,36 +1,61 @@
1
  exclude: ^patch
2
  repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.12.0
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- - repo: https://github.com/google/yapf
33
- rev: v0.32.0
34
- hooks:
35
- - id: yapf
36
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  exclude: ^patch
2
  repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.6.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ["--fix=lf"]
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.7.5
19
+ hooks:
20
+ - id: docformatter
21
+ args: ["--in-place"]
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.13.2
24
+ hooks:
25
+ - id: isort
26
+ args: ["--profile", "black"]
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v1.10.0
29
+ hooks:
30
+ - id: mypy
31
+ args: ["--ignore-missing-imports"]
32
+ additional_dependencies:
33
+ [
34
+ "types-python-slugify",
35
+ "types-requests",
36
+ "types-PyYAML",
37
+ "types-pytz",
38
+ ]
39
+ - repo: https://github.com/psf/black
40
+ rev: 24.4.2
41
+ hooks:
42
+ - id: black
43
+ language_version: python3.10
44
+ args: ["--line-length", "119"]
45
+ - repo: https://github.com/kynan/nbstripout
46
+ rev: 0.7.1
47
+ hooks:
48
+ - id: nbstripout
49
+ args:
50
+ [
51
+ "--extra-keys",
52
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
53
+ ]
54
+ - repo: https://github.com/nbQA-dev/nbQA
55
+ rev: 1.8.5
56
+ hooks:
57
+ - id: nbqa-black
58
+ - id: nbqa-pyupgrade
59
+ args: ["--py37-plus"]
60
+ - id: nbqa-isort
61
+ args: ["--float-to-top"]
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📊
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
app.py CHANGED
@@ -15,29 +15,25 @@ import torch
15
  import torch.nn as nn
16
  from huggingface_hub import hf_hub_download
17
 
18
- if os.environ.get('SYSTEM') == 'spaces':
19
- with open('patch') as f:
20
- subprocess.run(shlex.split('patch -p1'),
21
- cwd='stylegan2-pytorch',
22
- stdin=f)
23
  if not torch.cuda.is_available():
24
- with open('patch-cpu') as f:
25
- subprocess.run(shlex.split('patch -p1'),
26
- cwd='stylegan2-pytorch',
27
- stdin=f)
28
 
29
- sys.path.insert(0, 'stylegan2-pytorch')
30
 
31
  from model import Generator
32
 
33
- DESCRIPTION = '''# [TADNE](https://thisanimedoesnotexist.ai/) (This Anime Does Not Exist) interpolation
34
 
35
  Related Apps:
36
  - [TADNE](https://huggingface.co/spaces/hysts/TADNE)
37
  - [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
38
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
39
  - [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
40
- '''
41
 
42
  MAX_SEED = np.iinfo(np.int32).max
43
 
@@ -50,13 +46,12 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
50
 
51
  def load_model(device: torch.device) -> nn.Module:
52
  model = Generator(512, 1024, 4, channel_multiplier=2)
53
- path = hf_hub_download('public-data/TADNE',
54
- 'models/aydao-anime-danbooru2019s-512-5268480.pt')
55
  checkpoint = torch.load(path)
56
- model.load_state_dict(checkpoint['g_ema'])
57
  model.eval()
58
  model.to(device)
59
- model.latent_avg = checkpoint['latent_avg'].to(device)
60
  with torch.inference_mode():
61
  z = torch.zeros((1, model.style_dim)).to(device)
62
  model([z], truncation=0.7, truncation_latent=model.latent_avg)
@@ -64,26 +59,27 @@ def load_model(device: torch.device) -> nn.Module:
64
 
65
 
66
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
67
- return torch.from_numpy(np.random.RandomState(seed).randn(
68
- 1, z_dim)).to(device).float()
69
 
70
 
71
  @torch.inference_mode()
72
- def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float,
73
- randomize_noise: bool) -> np.ndarray:
74
- out, _ = model([z],
75
- truncation=truncation_psi,
76
- truncation_latent=model.latent_avg,
77
- randomize_noise=randomize_noise)
78
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
79
  return out[0].cpu().numpy()
80
 
81
 
82
  @torch.inference_mode()
83
- def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
84
- psi0: float, psi1: float,
85
- randomize_noise: bool, model: nn.Module,
86
- device: torch.device) -> list[np.ndarray]:
 
 
 
 
 
 
87
  seed0 = int(np.clip(seed0, 0, MAX_SEED))
88
  seed1 = int(np.clip(seed1, 0, MAX_SEED))
89
 
@@ -101,11 +97,9 @@ def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
101
  return res
102
 
103
 
104
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
105
  model = load_model(device)
106
- fn = functools.partial(generate_interpolated_images,
107
- model=model,
108
- device=device)
109
 
110
  examples = [
111
  [29703, 55376, 3, 0.7, 0.7, False],
@@ -115,41 +109,25 @@ examples = [
115
  [55376, 55376, 5, 0.3, 1.3, False],
116
  ]
117
 
118
- with gr.Blocks(css='style.css') as demo:
119
  gr.Markdown(DESCRIPTION)
120
  with gr.Row():
121
  with gr.Column():
122
- seed_1 = gr.Slider(label='Seed 1',
123
- minimum=0,
124
- maximum=MAX_SEED,
125
- step=1,
126
- value=29703)
127
- seed_2 = gr.Slider(label='Seed 2',
128
- minimum=0,
129
- maximum=MAX_SEED,
130
- step=1,
131
- value=55376)
132
  num_intermediate_frames = gr.Slider(
133
- label='Number of Intermediate Frames',
134
  minimum=1,
135
  maximum=21,
136
  step=1,
137
  value=3,
138
  )
139
- psi_1 = gr.Slider(label='Truncation psi 1',
140
- minimum=0,
141
- maximum=2,
142
- step=0.05,
143
- value=0.7)
144
- psi_2 = gr.Slider(label='Truncation psi 2',
145
- minimum=0,
146
- maximum=2,
147
- step=0.05,
148
- value=0.7)
149
- randomize_noise = gr.Checkbox(label='Randomize Noise', value=False)
150
- run_button = gr.Button('Run')
151
  with gr.Column():
152
- result = gr.Gallery(label='Output')
153
 
154
  inputs = [
155
  seed_1,
@@ -164,12 +142,12 @@ with gr.Blocks(css='style.css') as demo:
164
  inputs=inputs,
165
  outputs=result,
166
  fn=fn,
167
- cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
168
  )
169
  run_button.click(
170
  fn=fn,
171
  inputs=inputs,
172
  outputs=result,
173
- api_name='run',
174
  )
175
  demo.queue(max_size=10).launch()
 
15
  import torch.nn as nn
16
  from huggingface_hub import hf_hub_download
17
 
18
+ if os.environ.get("SYSTEM") == "spaces":
19
+ with open("patch") as f:
20
+ subprocess.run(shlex.split("patch -p1"), cwd="stylegan2-pytorch", stdin=f)
 
 
21
  if not torch.cuda.is_available():
22
+ with open("patch-cpu") as f:
23
+ subprocess.run(shlex.split("patch -p1"), cwd="stylegan2-pytorch", stdin=f)
 
 
24
 
25
+ sys.path.insert(0, "stylegan2-pytorch")
26
 
27
  from model import Generator
28
 
29
+ DESCRIPTION = """# [TADNE](https://thisanimedoesnotexist.ai/) (This Anime Does Not Exist) interpolation
30
 
31
  Related Apps:
32
  - [TADNE](https://huggingface.co/spaces/hysts/TADNE)
33
  - [TADNE Image Viewer](https://huggingface.co/spaces/hysts/TADNE-image-viewer)
34
  - [TADNE Image Selector](https://huggingface.co/spaces/hysts/TADNE-image-selector)
35
  - [TADNE Image Search with DeepDanbooru](https://huggingface.co/spaces/hysts/TADNE-image-search-with-DeepDanbooru)
36
+ """
37
 
38
  MAX_SEED = np.iinfo(np.int32).max
39
 
 
46
 
47
  def load_model(device: torch.device) -> nn.Module:
48
  model = Generator(512, 1024, 4, channel_multiplier=2)
49
+ path = hf_hub_download("public-data/TADNE", "models/aydao-anime-danbooru2019s-512-5268480.pt")
 
50
  checkpoint = torch.load(path)
51
+ model.load_state_dict(checkpoint["g_ema"])
52
  model.eval()
53
  model.to(device)
54
+ model.latent_avg = checkpoint["latent_avg"].to(device)
55
  with torch.inference_mode():
56
  z = torch.zeros((1, model.style_dim)).to(device)
57
  model([z], truncation=0.7, truncation_latent=model.latent_avg)
 
59
 
60
 
61
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
62
+ return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
 
63
 
64
 
65
  @torch.inference_mode()
66
+ def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float, randomize_noise: bool) -> np.ndarray:
67
+ out, _ = model([z], truncation=truncation_psi, truncation_latent=model.latent_avg, randomize_noise=randomize_noise)
 
 
 
 
68
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
69
  return out[0].cpu().numpy()
70
 
71
 
72
  @torch.inference_mode()
73
+ def generate_interpolated_images(
74
+ seed0: int,
75
+ seed1: int,
76
+ num_intermediate: int,
77
+ psi0: float,
78
+ psi1: float,
79
+ randomize_noise: bool,
80
+ model: nn.Module,
81
+ device: torch.device,
82
+ ) -> list[np.ndarray]:
83
  seed0 = int(np.clip(seed0, 0, MAX_SEED))
84
  seed1 = int(np.clip(seed1, 0, MAX_SEED))
85
 
 
97
  return res
98
 
99
 
100
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
101
  model = load_model(device)
102
+ fn = functools.partial(generate_interpolated_images, model=model, device=device)
 
 
103
 
104
  examples = [
105
  [29703, 55376, 3, 0.7, 0.7, False],
 
109
  [55376, 55376, 5, 0.3, 1.3, False],
110
  ]
111
 
112
+ with gr.Blocks(css="style.css") as demo:
113
  gr.Markdown(DESCRIPTION)
114
  with gr.Row():
115
  with gr.Column():
116
+ seed_1 = gr.Slider(label="Seed 1", minimum=0, maximum=MAX_SEED, step=1, value=29703)
117
+ seed_2 = gr.Slider(label="Seed 2", minimum=0, maximum=MAX_SEED, step=1, value=55376)
 
 
 
 
 
 
 
 
118
  num_intermediate_frames = gr.Slider(
119
+ label="Number of Intermediate Frames",
120
  minimum=1,
121
  maximum=21,
122
  step=1,
123
  value=3,
124
  )
125
+ psi_1 = gr.Slider(label="Truncation psi 1", minimum=0, maximum=2, step=0.05, value=0.7)
126
+ psi_2 = gr.Slider(label="Truncation psi 2", minimum=0, maximum=2, step=0.05, value=0.7)
127
+ randomize_noise = gr.Checkbox(label="Randomize Noise", value=False)
128
+ run_button = gr.Button("Run")
 
 
 
 
 
 
 
 
129
  with gr.Column():
130
+ result = gr.Gallery(label="Output")
131
 
132
  inputs = [
133
  seed_1,
 
142
  inputs=inputs,
143
  outputs=result,
144
  fn=fn,
145
+ cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
146
  )
147
  run_button.click(
148
  fn=fn,
149
  inputs=inputs,
150
  outputs=result,
151
+ api_name="run",
152
  )
153
  demo.queue(max_size=10).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- numpy==1.23.5
2
- Pillow==10.0.0
3
  torch==2.0.1
4
  torchvision==0.15.2
 
1
+ numpy==1.26.4
2
+ Pillow==10.3.0
3
  torch==2.0.1
4
  torchvision==0.15.2
style.css CHANGED
@@ -1,7 +1,11 @@
1
  h1 {
2
  text-align: center;
 
3
  }
4
 
5
  #duplicate-button {
6
  margin: auto;
 
 
 
7
  }
 
1
  h1 {
2
  text-align: center;
3
+ display: block;
4
  }
5
 
6
  #duplicate-button {
7
  margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
  }