mike23415 committed
Commit 9a14904 · verified · 1 Parent(s): a45f4a4

Update app.py

Files changed (1): app.py +35 -115
app.py CHANGED
@@ -1,136 +1,56 @@
 import os
-import logging
-import pip
-from flask import Flask, request, jsonify, send_file
 import torch
+from flask import Flask, request, jsonify, send_file
+from diffusers import DiffusionPipeline
 from PIL import Image
-import io
+from io import BytesIO
 
-# Configure logging to stdout instead of files
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[logging.StreamHandler()]
-)
-logger = logging.getLogger(__name__)
-
-# Try to update the required packages
-try:
-    logger.info("Updating huggingface_hub and diffusers...")
-    pip.main(['install', '--upgrade', 'huggingface_hub', '--quiet'])
-    pip.main(['install', '--upgrade', 'diffusers', '--quiet'])
-except Exception as e:
-    logger.warning(f"Failed to update libraries: {str(e)}")
-
-# Set Hugging Face cache directory to a writable path
-os.environ['HF_HOME'] = '/tmp/hf_home'
-os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
-
-# Create cache directories if they don't exist
-os.makedirs('/tmp/hf_home', exist_ok=True)
-os.makedirs('/tmp/cache', exist_ok=True)
-os.makedirs('/tmp/diffusers_cache', exist_ok=True)
-
-# Global variable for the model
-pipe = None
-
-# Initialize the model at startup
-def load_model():
-    global pipe
-    try:
-        logger.info("Loading Zero123Plus model...")
-        # Import here to ensure the environment variables are set before import
-        from diffusers import AutoPipelineForImage2Image
-        from huggingface_hub import snapshot_download
-
-        try:
-            # First try to download the model files
-            model_path = snapshot_download(
-                "sudo-ai/zero123plus-v1.2",
-                cache_dir="/tmp/diffusers_cache",
-                local_files_only=False
-            )
-
-            # Then load from local path
-            pipe = AutoPipelineForImage2Image.from_pretrained(
-                model_path,
-                torch_dtype=torch.float32,
-                safety_checker=None,
-                low_cpu_mem_usage=True
-            )
-        except Exception as download_error:
-            logger.warning(f"Failed to download using snapshot_download: {str(download_error)}")
-
-            # Fallback to direct loading with local_files_only=False
-            pipe = AutoPipelineForImage2Image.from_pretrained(
-                "sudo-ai/zero123plus-v1.2",
-                torch_dtype=torch.float32,
-                cache_dir="/tmp/diffusers_cache",
-                safety_checker=None,
-                low_cpu_mem_usage=True,
-                local_files_only=False
-            )
-
-        pipe.to("cpu")
-        logger.info("Model loaded successfully")
-        return True
-    except Exception as e:
-        logger.error(f"Error loading model: {str(e)}")
-        return False
+# Optional: logs
+import logging
+logging.basicConfig(level=logging.INFO)
 
-# Initialize Flask app
+# Flask app setup
 app = Flask(__name__)
 
-# Load the model immediately
-load_model()
+# Load model once at startup
+logging.info("Loading Zero123Plus pipeline...")
 
-@app.route("/", methods=["GET"])
-def index():
-    # Check if logs parameter is present
-    if request.args.get('logs') == 'container':
-        return jsonify({"message": "Zero123Plus API is running.", "status": "logs viewed"})
-    return jsonify({"message": "Zero123Plus API is running."})
+MODEL_ID = "sudo-ai/zero123plus-v1.2"  # Or your preferred model
+try:
+    pipe = DiffusionPipeline.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.float16,
+        variant="fp16"
+    ).to("cuda" if torch.cuda.is_available() else "cpu")
+except Exception as e:
+    logging.error(f"Error loading model: {e}")
+    pipe = None
+
+@app.route("/")
+def health_check():
+    return jsonify({"status": "Zero123 API is running!"})
 
 @app.route("/generate", methods=["POST"])
-def generate():
-    try:
-        global pipe
-        # Check if model is loaded
-        if pipe is None:
-            success = load_model()
-            if not success:
-                return jsonify({"error": "Failed to initialize model"}), 500
-
-        if 'image' not in request.files:
-            logger.warning("No image uploaded")
-            return jsonify({"error": "No image uploaded"}), 400
+def generate_image():
+    if pipe is None:
+        return jsonify({"error": "Model not loaded properly"}), 500
 
-        image_file = request.files['image']
-        input_image = Image.open(image_file).convert("RGB")
-        logger.info(f"Received image of size {input_image.size}")
+    data = request.files.get("image")
+    if not data:
+        return jsonify({"error": "No image provided"}), 400
 
-        # Get optional parameters with defaults
-        num_steps = int(request.form.get('num_inference_steps', 25))
-
-        logger.info(f"Starting image generation with {num_steps} steps")
-        # Generate new views
-        result = pipe(
-            image=input_image,
-            num_inference_steps=num_steps
-        )
-
+    try:
+        input_image = Image.open(data).convert("RGB")
+        result = pipe(image=input_image, num_inference_steps=30)
         output_image = result.images[0]
-        logger.info(f"Generated image of size {output_image.size}")
 
-        # Save to a BytesIO object
-        img_io = io.BytesIO()
+        # Return as image file
+        img_io = BytesIO()
         output_image.save(img_io, 'PNG')
         img_io.seek(0)
-
         return send_file(img_io, mimetype='image/png')
-
     except Exception as e:
-        logger.error(f"Error during generation: {str(e)}")
+        logging.error(f"Generation error: {e}")
        return jsonify({"error": str(e)}), 500
 
 if __name__ == "__main__":
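
Review note on the new loading path: float16 weights are generally GPU-only (many CPU ops lack fp16 kernels), so the CPU fallback in the `.to(...)` call would run an fp16 pipeline on CPU, and variant="fp16" raises if the repo publishes no fp16 weight files. The upstream Zero123Plus model card also loads the checkpoint through a custom pipeline rather than the generic DiffusionPipeline resolver. A defensive sketch of that pattern follows; the custom_pipeline repo id is an assumption recalled from the model card and should be verified:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="sudo-ai/zero123plus-pipeline",  # assumed repo id; verify against the model card
    # fp16 halves GPU memory; fall back to fp32 on CPU, where fp16 support is spotty
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)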
 
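To exercise the updated endpoint end to end, a minimal client sketch (assumptions: the server runs on Flask's default http://127.0.0.1:5000, and input.png is any local image file):

import requests

# Upload under the "image" field, matching request.files.get("image") on the server
with open("input.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:5000/generate", files={"image": f})

if resp.ok:
    with open("output.png", "wb") as out:
        out.write(resp.content)  # the endpoint returns raw image/png bytes
else:
    print(resp.status_code, resp.json())  # error responses are JSON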