lmoss commited on
Commit
806f947
·
1 Parent(s): 77e0f29

added berea path download

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -1,29 +1,32 @@
1
  import streamlit as st
2
- import streamlit.components.v1 as components
3
  import pyvista as pv
4
- from pyvista import examples
5
- import numpy as np
6
  from dcgan import DCGAN3D_G
7
  import torch
8
  import requests
 
9
 
10
- url = "https://raw.githubusercontent.com/LukasMosser/PorousMediaGan/raw/master/checkpoints/berea/berea_generator_epoch_24.pth"
11
 
12
  # If repo is private - we need to add a token in header:
13
  resp = requests.get(url)
 
 
 
 
 
14
  print(resp.status_code)
15
  st.text(resp.status_code)
 
16
  pv.set_plot_theme("document")
17
  pl = pv.Plotter(shape=(1, 1),
18
  window_size=(800, 800))
19
-
20
  netG = DCGAN3D_G(64, 512, 1, 32, 1)
21
  netG.load_state_dict(torch.load("berea_generator_epoch_24.pth"))
22
  z = torch.randn(1, 512, 5, 5, 5)
23
  with torch.no_grad():
24
  X = netG(z)
25
- print(X.size())
26
- print(X.min(), X.max())
27
  st.image((X[0, 0, 32].numpy()+1)/2, output_format="png")
28
  """
29
  data = examples.load_channels()
 
1
  import streamlit as st
 
2
  import pyvista as pv
 
 
3
  from dcgan import DCGAN3D_G
4
  import torch
5
  import requests
6
+ import time
7
 
8
+ url = "https://github.com/LukasMosser/PorousMediaGan/blob/master/checkpoints/berea/berea_generator_epoch_24.pth?raw=true"
9
 
10
  # If repo is private - we need to add a token in header:
11
  resp = requests.get(url)
12
+
13
+ with open('berea_generator_epoch_24.pth', 'wb') as f:
14
+ f.write(resp.content)
15
+ time.sleep(5)
16
+
17
  print(resp.status_code)
18
  st.text(resp.status_code)
19
+
20
  pv.set_plot_theme("document")
21
  pl = pv.Plotter(shape=(1, 1),
22
  window_size=(800, 800))
23
+ print(torch.load("berea_generator_epoch_24.pth"))
24
  netG = DCGAN3D_G(64, 512, 1, 32, 1)
25
  netG.load_state_dict(torch.load("berea_generator_epoch_24.pth"))
26
  z = torch.randn(1, 512, 5, 5, 5)
27
  with torch.no_grad():
28
  X = netG(z)
29
+
 
30
  st.image((X[0, 0, 32].numpy()+1)/2, output_format="png")
31
  """
32
  data = examples.load_channels()