mike23415 committed on
Commit
2e2dac7
·
verified ·
1 Parent(s): 51a59b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -36
app.py CHANGED
@@ -1,54 +1,62 @@
1
- import os
2
- import sys
3
-
4
- # Add this before importing pipeline
5
- sys.path.append(os.path.dirname(__file__))
6
-
7
- from flask import Flask, request, jsonify
8
- from PIL import Image
9
  import torch
10
- from zero_pipeline import Zero123PlusPipeline # Now this will work
 
 
 
11
 
12
- app = Flask(__name__)
 
13
 
14
- # Load the pipeline once when the app starts
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- print(f"Running on device: {device}")
17
 
18
- pipe = Zero123PlusPipeline.from_pretrained(
19
- "sudo-ai/zero123plus-v1.2",
20
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
21
- )
22
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @app.route("/")
25
- def index():
26
- return "Zero123Plus API is running!"
27
 
28
- @app.route("/predict", methods=["POST"])
29
- def predict():
30
- if 'image' not in request.files:
31
- return jsonify({"error": "No image uploaded"}), 400
32
 
33
- image = request.files["image"]
34
  try:
35
- input_image = Image.open(image).convert("RGB")
 
 
 
 
36
 
37
- result = pipe(input_image, num_inference_steps=75, num_images_per_prompt=4)
 
38
 
39
- images = result.images # List of PIL Images
40
- output_dir = "outputs"
41
- os.makedirs(output_dir, exist_ok=True)
42
- saved_paths = []
43
 
44
- for i, img in enumerate(images):
45
- path = os.path.join(output_dir, f"output_{i}.png")
46
- img.save(path)
47
- saved_paths.append(path)
48
 
49
- return jsonify({"outputs": saved_paths})
50
 
51
  except Exception as e:
 
52
  return jsonify({"error": str(e)}), 500
53
 
54
  if __name__ == "__main__":
 
1
+ import io
2
+ import base64
 
 
 
 
 
 
3
  import torch
4
+ from flask import Flask, request, jsonify, send_file
5
+ from diffusers import DiffusionPipeline
6
+ from PIL import Image
7
+ import logging
8
 
9
# --- Module-level initialization: logging, Flask app, and one-time model load ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Load the model once at startup (on CPU)
# NOTE(review): from_pretrained downloads/loads weights at import time, so the
# first startup can be slow and needs network access — confirm this is intended
# for the deployment environment.
try:
    logger.info("Loading Zero123Plus pipeline...")
    pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        torch_dtype=torch.float32,  # CPU needs float32
    )
    pipe.to("cpu")
    logger.info("=== Application Startup at CPU mode =====")
except Exception as e:
    # On any load failure, keep the app alive with pipe=None; the /generate
    # route checks this sentinel and returns a 500 instead of crashing.
    logger.error(f"Error loading model: {e}")
    pipe = None
26
+
27
def pil_to_base64(image):
    """Encode an image as a base64 PNG string (no data-URI prefix).

    Args:
        image: a PIL-style image exposing ``save(buffer, format=...)``.

    Returns:
        str: the PNG bytes, base64-encoded and UTF-8 decoded.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
31
 
32
@app.route("/")
def home():
    """Liveness probe: confirms the API process is up and serving."""
    status_message = "Zero123Plus CPU API is running!"
    return status_message
35
 
36
@app.route("/generate", methods=["POST"])
def generate():
    """Generate a novel view of an uploaded image via the Zero123Plus pipeline.

    Expects a JSON body: ``{"image": "<base64 or data-URI encoded image>"}``.

    Returns:
        200 with ``{"image": "data:image/png;base64,..."}`` on success;
        400 if the request body has no usable image;
        500 if the model failed to load or inference raised.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500

    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so bad requests get a clear 400 below rather than falling
        # through to the generic 500 handler with an AttributeError message.
        data = request.get_json(silent=True) or {}
        image_data = data.get("image")

        if not image_data:
            return jsonify({"error": "No image provided"}), 400

        # Strip a data-URI prefix (e.g. "data:image/png;base64,") if present.
        if image_data.startswith("data:image"):
            image_data = image_data.split(",")[1]

        image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")

        # NOTE(review): full diffusion inference on CPU per request — slow;
        # presumably acceptable for this demo deployment. Only the first
        # result image is returned.
        result = pipe(image)
        output_image = result.images[0]

        return jsonify({"image": f"data:image/png;base64,{pil_to_base64(output_image)}"})

    except Exception as e:
        logger.error(f"Error generating image: {e}")
        return jsonify({"error": str(e)}), 500
61
 
62
  if __name__ == "__main__":