# Source: magical-box / app.py
# Commit cb944d7 by mac9087: "Create app.py" (2.06 kB)
# app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.models.download import load_checkpoint
from point_e.models.text_to_image import create_model_and_diffusion as create_t2i
from point_e.models.image_to_3d import create_model_and_diffusion as create_i2p
from point_e.util.plotting import save_point_cloud
import tempfile
import os
# Flask application with flask-cors applied using its default settings
# (NOTE(review): confirm the default allowed origins suit this deployment).
app = Flask(__name__)
CORS(app)

# Load models once on startup: everything below runs at import time, so each
# request reuses the same in-memory weights instead of reloading them.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


def _load_pretrained(factory, checkpoint_name, banner):
    """Print *banner*, then build, load and freeze one point_e model.

    *factory* is called with ``DIFFUSION_CONFIGS['base40']`` and the
    module-level ``device`` and must return ``(model, diffusion)``. The
    model's weights are replaced with the checkpoint *checkpoint_name* and
    the model is switched to eval mode. Returns the ``(model, diffusion)``
    pair unchanged otherwise.
    """
    print(banner)
    net, sampler = factory(DIFFUSION_CONFIGS['base40'], device)
    net.load_state_dict(load_checkpoint(checkpoint_name, device))
    net.eval()
    return net, sampler


# NOTE(review): both models are built from the same 'base40' diffusion
# config. Upstream point-e appears to name its configs/checkpoints like
# 'base40M-textvec' / 'base40M-imagevec'. Confirm that 'base40' and
# 'base40-img2pc' (and the create_t2i / create_i2p factories) exist in the
# installed package.
t2i_model, t2i_diffusion = _load_pretrained(
    create_t2i, 'base40', "Loading text-to-image model..."
)
i2p_model, i2p_diffusion = _load_pretrained(
    create_i2p, 'base40-img2pc', "Loading image-to-3d model..."
)
@app.route('/generate', methods=['POST'])
def generate():
    """Turn a text prompt into a point cloud saved as a PLY file.

    Request body: a JSON object such as ``{"prompt": "a red car"}``. The
    ``prompt`` key is optional and defaults to ``'a red car'``.

    Pipeline: the text-to-image model samples an image from the prompt.
    The image-to-3d model then samples a point cloud conditioned on that
    image. The result is written to ``output.ply`` inside a fresh temporary
    directory.

    Returns:
        200 ``{"model_path": <path>}``: the path is on the server's own
        filesystem.
        400 ``{"error": ...}``: the body is not a JSON object, or ``prompt``
        is not a non-empty string.
        500 ``{"error": ...}``: generation or saving failed; the traceback
        is logged.
    """
    # silent=True makes a missing/malformed body or a non-JSON Content-Type
    # return None instead of raising. That way every rejection below is
    # answered with JSON, and a JSON null/list/string body no longer crashes
    # on .get().
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400
    prompt = data.get('prompt', 'a red car')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({'error': "'prompt' must be a non-empty string"}), 400

    try:
        # Pure inference: no autograd graph is needed, so don't build one.
        with torch.no_grad():
            # Stage 1: text -> image.
            print("Generating image from text...")
            samples = t2i_diffusion.sample_loop(
                model=t2i_model,
                shape=(1, 3, 64, 64),
                device=device,
                clip_denoised=True,
                model_kwargs={"text": [prompt]},
                progress=True,
            )
            image = t2i_model.output_to_image(samples[0])

            # Stage 2: image -> point cloud.
            print("Generating point cloud...")
            pc = i2p_diffusion.sample_loop(
                model=i2p_model,
                # NOTE(review): this reuses the image-shaped (1, 3, 64, 64)
                # from stage 1. A point-cloud sample is usually
                # (batch, channels, num_points). Confirm against the
                # installed point_e API.
                shape=(1, 3, 64, 64),
                device=device,
                clip_denoised=True,
                model_kwargs={"images": [image]},
                progress=True,
            )

        # mkdtemp creates a new, uniquely named directory per call, so
        # concurrent requests never overwrite each other's output.ply.
        # NOTE(review): nothing removes these directories, so they accumulate.
        # The returned path is also only usable by processes on this host.
        # Consider streaming the file back and deleting it instead.
        tmp_dir = tempfile.mkdtemp()
        output_path = os.path.join(tmp_dir, 'output.ply')
        save_point_cloud(pc[0], output_path)
    except Exception:
        # Route boundary: log the full traceback and keep the API's JSON
        # contract instead of letting Flask emit an HTML 500 page.
        app.logger.exception("Point cloud generation failed for prompt %r", prompt)
        return jsonify({'error': 'generation failed'}), 500

    return jsonify({'model_path': output_path})
if __name__ == '__main__':
    # Only when this file is executed as a script: start Flask's built-in
    # server, listening on all network interfaces at port 7860.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(port=listen_port, host=listen_host)