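# NOTE(review): the capture this file came from began with the lines
# "Spaces:", "Sleeping", "Sleeping" — these appear to be page-header text from
# the hosting site, not Python, and are kept here only as a comment so the
# module parses.
# app.py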
import os
import tempfile

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.models.download import load_checkpoint
from point_e.models.image_to_3d import create_model_and_diffusion as create_i2p
from point_e.models.text_to_image import create_model_and_diffusion as create_t2i
from point_e.util.plotting import save_point_cloud
app = Flask(__name__)
# CORS(app) with no options; flask-cors's default is to allow every origin —
# NOTE(review): confirm that is intended for this endpoint.
CORS(app)


def _load_eval_model(factory, config_name, checkpoint_name, banner, target_device):
    """Build a (model, diffusion) pair, load its weights, and switch to eval mode.

    Args:
        factory: a ``create_model_and_diffusion``-style callable taking
            ``(config, device)`` and returning ``(model, diffusion)``.
        config_name: key into ``DIFFUSION_CONFIGS``.
        checkpoint_name: name passed to ``load_checkpoint``.
        banner: message printed before anything is built.
        target_device: the ``torch.device`` to build and load onto.

    Returns:
        The ``(model, diffusion)`` tuple produced by ``factory``, with the
        checkpoint weights loaded and ``model.eval()`` already called.
    """
    print(banner)
    config = DIFFUSION_CONFIGS[config_name]
    model, diffusion = factory(config, target_device)
    weights = load_checkpoint(checkpoint_name, target_device)
    model.load_state_dict(weights)
    model.eval()
    return model, diffusion


# Both models are built once at import time; generate() reuses these globals
# for every request.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

t2i_model, t2i_diffusion = _load_eval_model(
    create_t2i, 'base40', 'base40', "Loading text-to-image model...", device
)
# NOTE(review): the image-to-3d model is built from the 'base40' config but
# loads the 'base40-img2pc' checkpoint — confirm the config matches it.
i2p_model, i2p_diffusion = _load_eval_model(
    create_i2p, 'base40', 'base40-img2pc', "Loading image-to-3d model...", device
)
# NOTE(review): the original never registered this view with the app (no
# route decorator or add_url_rule anywhere in the file), so it was unreachable.
# The path '/generate' is inferred from the function name — confirm what
# clients actually call.
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a point cloud from a text prompt and save it as a .ply file.

    The request body must be a JSON object. Its optional ``"prompt"`` key
    (default ``'a red car'``) conditions the text-to-image model. The first
    generated image is fed to the image-to-point-cloud model, and the first
    resulting sample is written to ``output.ply`` in a new temporary directory.

    Returns:
        ``{'model_path': <path on the server>}`` on success, or
        ``({'error': ...}, 400)`` when the body is not a JSON object or
        ``"prompt"`` is not a string.
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # malformed JSON body, so bad input becomes a 400 here instead of an
    # AttributeError (HTTP 500) on `.get`.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400
    prompt = data.get('prompt', 'a red car')
    if not isinstance(prompt, str):
        return jsonify({'error': "'prompt' must be a string"}), 400

    # Inference only: disable autograd so no computation graph is built and
    # kept alive across the diffusion sampling steps.
    with torch.no_grad():
        # Generate image from text
        print("Generating image from text...")
        samples = t2i_diffusion.sample_loop(
            model=t2i_model,
            shape=(1, 3, 64, 64),
            device=device,
            clip_denoised=True,
            model_kwargs={"text": [prompt]},
            progress=True,
        )
        image = t2i_model.output_to_image(samples[0])

        # Generate point cloud from image
        print("Generating point cloud...")
        # NOTE(review): this reuses the image sample shape (1, 3, 64, 64) for
        # the point-cloud sampler; confirm it matches what i2p_model produces.
        pc = i2p_diffusion.sample_loop(
            model=i2p_model,
            shape=(1, 3, 64, 64),
            device=device,
            clip_denoised=True,
            model_kwargs={"images": [image]},
            progress=True,
        )

    # TODO(review): mkdtemp() creates a new directory on every request and
    # nothing in this file ever removes it, so they accumulate on disk.
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, 'output.ply')
    save_point_cloud(pc[0], output_path)
    # NOTE(review): this is a path on the server's filesystem; a remote
    # (e.g. cross-origin browser) client cannot open it directly.
    return jsonify({'model_path': output_path})
if __name__ == '__main__':
    # Start Flask's built-in server bound to every IPv4 interface (0.0.0.0)
    # so other machines can reach it, listening on port 7860.
    app.run(
        host='0.0.0.0',
        port=7860,
    )