mike23415 commited on
Commit
8a1bb06
·
verified ·
1 Parent(s): 5ba0d31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -13
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import logging
 
3
  from flask import Flask, request, jsonify, send_file
4
  import torch
5
  from PIL import Image
@@ -13,6 +14,14 @@ logging.basicConfig(
13
  )
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
16
  # Set Hugging Face cache directory to a writable path
17
  os.environ['HF_HOME'] = '/tmp/hf_home'
18
  os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
@@ -32,15 +41,36 @@ def load_model():
32
  logger.info("Loading Zero123Plus model...")
33
  # Import here to ensure the environment variables are set before import
34
  from diffusers import AutoPipelineForImage2Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Use AutoPipelineForImage2Image which should work with any image-to-image model
37
- pipe = AutoPipelineForImage2Image.from_pretrained(
38
- "sudo-ai/zero123plus-v1.2",
39
- torch_dtype=torch.float32,
40
- cache_dir="/tmp/diffusers_cache",
41
- safety_checker=None,
42
- low_cpu_mem_usage=True
43
- )
44
  pipe.to("cpu")
45
  logger.info("Model loaded successfully")
46
  return True
@@ -48,18 +78,17 @@ def load_model():
48
  logger.error(f"Error loading model: {str(e)}")
49
  return False
50
 
51
- # Don't try to load the model at startup - we'll load it on the first request
52
- # This prevents the app from crashing if the model can't be loaded immediately
53
-
54
  app = Flask(__name__)
55
 
56
  # Load the model immediately
57
  load_model()
58
 
59
- app = Flask(__name__)
60
-
61
  @app.route("/", methods=["GET"])
62
  def index():
 
 
 
63
  return jsonify({"message": "Zero123Plus API is running."})
64
 
65
  @app.route("/generate", methods=["POST"])
 
1
  import os
2
  import logging
3
+ import pip
4
  from flask import Flask, request, jsonify, send_file
5
  import torch
6
  from PIL import Image
 
14
  )
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Try to update the required packages
18
+ try:
19
+ logger.info("Updating huggingface_hub and diffusers...")
20
+ pip.main(['install', '--upgrade', 'huggingface_hub', '--quiet'])
21
+ pip.main(['install', '--upgrade', 'diffusers', '--quiet'])
22
+ except Exception as e:
23
+ logger.warning(f"Failed to update libraries: {str(e)}")
24
+
25
  # Set Hugging Face cache directory to a writable path
26
  os.environ['HF_HOME'] = '/tmp/hf_home'
27
  os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
 
41
  logger.info("Loading Zero123Plus model...")
42
  # Import here to ensure the environment variables are set before import
43
  from diffusers import AutoPipelineForImage2Image
44
+ from huggingface_hub import snapshot_download
45
+
46
+ try:
47
+ # First try to download the model files
48
+ model_path = snapshot_download(
49
+ "sudo-ai/zero123plus-v1.2",
50
+ cache_dir="/tmp/diffusers_cache",
51
+ local_files_only=False
52
+ )
53
+
54
+ # Then load from local path
55
+ pipe = AutoPipelineForImage2Image.from_pretrained(
56
+ model_path,
57
+ torch_dtype=torch.float32,
58
+ safety_checker=None,
59
+ low_cpu_mem_usage=True
60
+ )
61
+ except Exception as download_error:
62
+ logger.warning(f"Failed to download using snapshot_download: {str(download_error)}")
63
+
64
+ # Fallback to direct loading with local_files_only=False
65
+ pipe = AutoPipelineForImage2Image.from_pretrained(
66
+ "sudo-ai/zero123plus-v1.2",
67
+ torch_dtype=torch.float32,
68
+ cache_dir="/tmp/diffusers_cache",
69
+ safety_checker=None,
70
+ low_cpu_mem_usage=True,
71
+ local_files_only=False
72
+ )
73
 
 
 
 
 
 
 
 
 
74
  pipe.to("cpu")
75
  logger.info("Model loaded successfully")
76
  return True
 
78
  logger.error(f"Error loading model: {str(e)}")
79
  return False
80
 
81
+ # Initialize Flask app
 
 
82
  app = Flask(__name__)
83
 
84
  # Load the model immediately
85
  load_model()
86
 
 
 
87
  @app.route("/", methods=["GET"])
88
  def index():
89
+ # Check if logs parameter is present
90
+ if request.args.get('logs') == 'container':
91
+ return jsonify({"message": "Zero123Plus API is running.", "status": "logs viewed"})
92
  return jsonify({"message": "Zero123Plus API is running."})
93
 
94
  @app.route("/generate", methods=["POST"])