Spaces:
Build error
Build error
File size: 3,985 Bytes
19677a1 0d35ba8 19677a1 e4b8bbd 19677a1 0d35ba8 0071656 19677a1 0071656 19677a1 13bc063 0d35ba8 e4b8bbd 0071656 bd597e9 5272de4 e4b8bbd 0071656 bd597e9 19677a1 fd209d1 bb689c9 fd209d1 0d35ba8 0071656 0d35ba8 0071656 19677a1 0d35ba8 19677a1 e4b8bbd 6cbae78 bb689c9 6cbae78 5272de4 4f54252 c1f7cd5 4f54252 5272de4 0071656 13bc063 8bc9869 1a30119 0071656 1a30119 bb689c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import builtins
import math
import json
import streamlit as st
import gdown
# from google_drive_downloader import GoogleDriveDownloader as gdd
from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image
# from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
# Page configuration is applied before any other Streamlit output.
st.set_page_config(page_title="DietNeRF")

# Per-scene settings keyed by the scene label chosen in select_model; each
# entry must provide "DIET_NERF_MODEL_NAME" and "FILE_ID". The path is
# resolved relative to the current working directory. JSON text is UTF-8 by
# specification, so decode it explicitly rather than with the locale default.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)
def select_model():
    """Let the user pick a scene and return where its checkpoint lives.

    The label chosen in the selectbox is used as the key into the
    module-level ``cfg`` loaded from ``config.json``.

    Returns:
        tuple[str, str, str]: ``(model_dir, model_name, file_id)``:
        ``model_dir`` is the local directory checkpoints are stored in,
        ``model_name`` is the checkpoint file name for the scene, and
        ``file_id`` is the Google Drive id used to download it.

    Raises:
        KeyError: if the selected scene, or one of its required fields, is
            missing from ``config.json``.
    """
    obj_select = st.selectbox("Select a scene", ("Chair", "Lego", "Ship", "Hotdog"))
    # Per-scene names/ids come only from config.json, so an updated model
    # upload needs a config change rather than a code change.
    scene_cfg = cfg[obj_select]
    model_dir = "models"
    model_name = scene_cfg["DIET_NERF_MODEL_NAME"]
    file_id = scene_cfg["FILE_ID"]
    return model_dir, model_name, file_id
st.title("DietNeRF")
# Short pitch shown under the title; the pieces form one paragraph.
caption = "".join(
    [
        "DietNeRF achieves SoTA few-shot learning capacity in 3D model reconstruction. ",
        "Thanks to the 2D supervision by CLIP (aka semantic loss), ",
        "it can render novel and challenging views with ONLY 8 training images, ",
        "outperforming original NeRF!",
    ]
)
st.markdown(caption)
st.markdown("")  # blank spacer before the scene picker

# Resolve the checkpoint location for the chosen scene. These module-level
# names are read by download_model / fetch_model further down.
MODEL_DIR, MODEL_NAME, FILE_ID = select_model()
@st.cache
def download_model():
    """Download the selected scene's checkpoint from Google Drive.

    Reads the module-level ``MODEL_DIR``, ``MODEL_NAME`` and ``FILE_ID`` set
    from ``select_model`` and writes the file to ``MODEL_DIR/MODEL_NAME``,
    creating ``MODEL_DIR`` if needed.

    Returns:
        str: path of the downloaded checkpoint.

    Raises:
        RuntimeError: if no file exists at the destination after the
            download attempt.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    _model_path = os.path.join(MODEL_DIR, MODEL_NAME)
    url = f"https://drive.google.com/uc?id={FILE_ID}"
    gdown.download(url, _model_path, quiet=False)
    # gdown may report a failed download (quota exceeded, access denied, ...)
    # by printing a message and returning rather than raising. Since this
    # function is memoized by st.cache, such a silent failure would be cached
    # and never retried for the same scene. Raising keeps it out of the cache
    # and gives a clear error instead of a later "file not found" on load.
    if not os.path.isfile(_model_path):
        raise RuntimeError(
            f"Failed to download model from google drive (id={FILE_ID}) "
            f"to {_model_path}"
        )
    print(f"Model downloaded from google drive: {_model_path}")
    return _model_path
@st.cache(allow_output_mutation=True, show_spinner=False)
def fetch_model():
    """Load (and memoize) the trained model for the selected scene.

    Uses the module-level ``MODEL_DIR`` and ``MODEL_NAME``.
    ``allow_output_mutation=True`` tells st.cache not to check the returned
    objects for mutation between reruns.

    Returns:
        tuple: ``(model, state)`` unpacked from ``load_trained_model``.
    """
    loaded = load_trained_model(MODEL_DIR, MODEL_NAME)
    trained_model, train_state = loaded
    return trained_model, train_state
# Fetch the checkpoint only when it is not already on disk, then load it.
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
checkpoint_missing = not os.path.isfile(model_path)
if checkpoint_missing:
    download_model()
model, state = fetch_model()
pi = math.pi

# Sidebar branding: a centred project image, then centred links to the code
# and the project report. Both snippets are raw HTML, so they are rendered
# with unsafe_allow_html=True.
_SIDEBAR_LOGO_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p class="aligncenter">
<img src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png" width="410" height="400"/>
</p>
"""
_SIDEBAR_LINKS_HTML = """
<p style='text-align: center'>
<a href="https://github.com/codestella/putting-nerf-on-a-diet" target="_blank">GitHub</a> | <a href="https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745" target="_blank">Project Report</a>
</p>
"""
st.sidebar.markdown(_SIDEBAR_LOGO_HTML, unsafe_allow_html=True)
st.sidebar.markdown(_SIDEBAR_LINKS_HTML, unsafe_allow_html=True)
# View-direction controls: theta spans [-pi, pi], phi spans [0, pi/2], and
# radius is the distance between the object and the viewer.
st.sidebar.header("SELECT YOUR VIEW DIRECTION")
theta = st.sidebar.slider(
    "Theta",
    help="Rotational angle in Horizontal direction",
    min_value=-pi,
    max_value=pi,
    value=0.0,
    step=0.5,
)
phi = st.sidebar.slider(
    "Phi",
    help="Rotational angle in Vertical direction",
    min_value=0.0,
    max_value=pi / 2,
    value=1.0,
    step=0.1,
)
radius = st.sidebar.slider(
    "Radius",
    help="Distance between object and the viewer",
    min_value=2.0,
    max_value=6.0,
    value=3.0,
    step=1.0,
)
st.markdown("")
# Render the scene from the chosen pose and show it, upscaled, while a
# spinner is displayed.
with st.spinner("Rendering Image, it may take 2-3 mins. So, why don't you read our report in the meantime"):
    pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
    # NOTE(review): im is assumed to be a PIL.Image (it is used via .size and
    # .resize) -- confirm against predict_to_image.
    im = predict_to_image(pred_color)
    # Upscale for display. Width and height are scaled independently so a
    # non-square render keeps its aspect ratio.
    scale = 2
    w, h = im.size
    im = im.resize(size=(int(scale * w), int(scale * h)))
    st.image(im, use_column_width="auto")
|