jameszokah commited on
Commit
2acc39d
·
1 Parent(s): 5b49b7d

updated the path to the APP_DIR in the main file

Browse files
Files changed (1) hide show
  1. app/main.py +20 -18
app/main.py CHANGED
@@ -54,18 +54,19 @@ async def lifespan(app: FastAPI):
54
  app.state.logger = logger # Make logger available to routes
55
 
56
  # Create necessary directories - use persistent locations
57
- os.makedirs("/app/models", exist_ok=True)
58
- os.makedirs("/app/tokenizers", exist_ok=True)
59
- os.makedirs("/app/voice_memories", exist_ok=True)
60
- os.makedirs("/app/voice_references", exist_ok=True)
61
- os.makedirs("/app/voice_profiles", exist_ok=True)
62
- os.makedirs("/app/cloned_voices", exist_ok=True)
63
- os.makedirs("/app/audio_cache", exist_ok=True)
64
- os.makedirs("/app/static", exist_ok=True)
 
65
 
66
  # Set tokenizer cache
67
  try:
68
- os.environ["TRANSFORMERS_CACHE"] = "/app/tokenizers"
69
  logger.info(f"Set tokenizer cache to: {os.environ['TRANSFORMERS_CACHE']}")
70
  except Exception as e:
71
  logger.error(f"Error setting tokenizer cache: {e}")
@@ -109,7 +110,7 @@ async def lifespan(app: FastAPI):
109
  app.state.device_map = device_map
110
 
111
  # Check if model file exists
112
- model_path = os.path.join("/app/models", "ckpt.pt")
113
  if not os.path.exists(model_path):
114
  # Try to download at runtime if not present
115
  logger.info("Model not found. Attempting to download...")
@@ -125,7 +126,7 @@ async def lifespan(app: FastAPI):
125
  model_path = hf_hub_download(
126
  repo_id="sesame/csm-1b",
127
  filename="ckpt.pt",
128
- local_dir="/app/models"
129
  )
130
  download_time = time.time() - download_start
131
  logger.info(f"Model downloaded to {model_path} in {download_time:.2f} seconds")
@@ -180,7 +181,7 @@ async def lifespan(app: FastAPI):
180
  logger.info("Initializing voice cloning system...")
181
  from app.voice_cloning import VoiceCloner, CLONED_VOICES_DIR
182
  # Update the cloned voices directory to use the persistent volume
183
- app.state.cloned_voices_dir = "/app/cloned_voices" # Store path in app state for access
184
  os.makedirs(app.state.cloned_voices_dir, exist_ok=True)
185
  CLONED_VOICES_DIR = app.state.cloned_voices_dir # Update the module constant
186
 
@@ -362,7 +363,7 @@ async def lifespan(app: FastAPI):
362
  # Set up audio cache
363
  app.state.audio_cache_enabled = os.environ.get("ENABLE_AUDIO_CACHE", "true").lower() == "true"
364
  if app.state.audio_cache_enabled:
365
- app.state.audio_cache_dir = "/app/audio_cache"
366
  logger.info(f"Audio cache enabled, cache dir: {app.state.audio_cache_dir}")
367
 
368
  # Log GPU utilization after model loading
@@ -497,11 +498,12 @@ app.add_middleware(
497
  )
498
 
499
  # Create static and other required directories
500
- os.makedirs("/app/static", exist_ok=True)
501
- os.makedirs("/app/cloned_voices", exist_ok=True)
 
502
 
503
  # Mount the static files directory
504
- app.mount("/static", StaticFiles(directory="/app/static"), name="static")
505
 
506
  # Include routers
507
  app.include_router(api_router, prefix="/api/v1")
@@ -588,13 +590,13 @@ async def version():
588
  @app.get("/voice-cloning", include_in_schema=False)
589
  async def voice_cloning_ui():
590
  """Voice cloning UI endpoint."""
591
- return FileResponse("/app/static/voice-cloning.html")
592
 
593
  # Streaming demo endpoint
594
  @app.get("/streaming-demo", include_in_schema=False)
595
  async def streaming_demo():
596
  """Streaming TTS demo endpoint."""
597
- return FileResponse("/app/static/streaming-demo.html")
598
 
599
  @app.get("/", include_in_schema=False)
600
  async def root():
 
54
  app.state.logger = logger # Make logger available to routes
55
 
56
  # Create necessary directories - use persistent locations
57
+ APP_DIR = os.path.join(os.environ['HOME'], 'app')
58
+ os.makedirs(os.path.join(APP_DIR, "models"), exist_ok=True)
59
+ os.makedirs(os.path.join(APP_DIR, "tokenizers"), exist_ok=True)
60
+ os.makedirs(os.path.join(APP_DIR, "voice_memories"), exist_ok=True)
61
+ os.makedirs(os.path.join(APP_DIR, "voice_references"), exist_ok=True)
62
+ os.makedirs(os.path.join(APP_DIR, "voice_profiles"), exist_ok=True)
63
+ os.makedirs(os.path.join(APP_DIR, "cloned_voices"), exist_ok=True)
64
+ os.makedirs(os.path.join(APP_DIR, "audio_cache"), exist_ok=True)
65
+ os.makedirs(os.path.join(APP_DIR, "static"), exist_ok=True)
66
 
67
  # Set tokenizer cache
68
  try:
69
+ os.environ["TRANSFORMERS_CACHE"] = os.path.join(APP_DIR, "tokenizers")
70
  logger.info(f"Set tokenizer cache to: {os.environ['TRANSFORMERS_CACHE']}")
71
  except Exception as e:
72
  logger.error(f"Error setting tokenizer cache: {e}")
 
110
  app.state.device_map = device_map
111
 
112
  # Check if model file exists
113
+ model_path = os.path.join(APP_DIR, "models", "ckpt.pt")
114
  if not os.path.exists(model_path):
115
  # Try to download at runtime if not present
116
  logger.info("Model not found. Attempting to download...")
 
126
  model_path = hf_hub_download(
127
  repo_id="sesame/csm-1b",
128
  filename="ckpt.pt",
129
+ local_dir=APP_DIR
130
  )
131
  download_time = time.time() - download_start
132
  logger.info(f"Model downloaded to {model_path} in {download_time:.2f} seconds")
 
181
  logger.info("Initializing voice cloning system...")
182
  from app.voice_cloning import VoiceCloner, CLONED_VOICES_DIR
183
  # Update the cloned voices directory to use the persistent volume
184
+ app.state.cloned_voices_dir = os.path.join(APP_DIR, "cloned_voices") # Store path in app state for access
185
  os.makedirs(app.state.cloned_voices_dir, exist_ok=True)
186
  CLONED_VOICES_DIR = app.state.cloned_voices_dir # Update the module constant
187
 
 
363
  # Set up audio cache
364
  app.state.audio_cache_enabled = os.environ.get("ENABLE_AUDIO_CACHE", "true").lower() == "true"
365
  if app.state.audio_cache_enabled:
366
+ app.state.audio_cache_dir = os.path.join(APP_DIR, "audio_cache")
367
  logger.info(f"Audio cache enabled, cache dir: {app.state.audio_cache_dir}")
368
 
369
  # Log GPU utilization after model loading
 
498
  )
499
 
500
  # Create static and other required directories
501
+ APP_DIR = os.path.join(os.environ['HOME'], 'app')
502
+ os.makedirs(os.path.join(APP_DIR, "static"), exist_ok=True)
503
+ os.makedirs(os.path.join(APP_DIR, "cloned_voices"), exist_ok=True)
504
 
505
  # Mount the static files directory
506
+ app.mount("/static", StaticFiles(directory=os.path.join(APP_DIR, "static")), name="static")
507
 
508
  # Include routers
509
  app.include_router(api_router, prefix="/api/v1")
 
590
  @app.get("/voice-cloning", include_in_schema=False)
591
  async def voice_cloning_ui():
592
  """Voice cloning UI endpoint."""
593
+ return FileResponse(os.path.join(APP_DIR, "static/voice-cloning.html"))
594
 
595
  # Streaming demo endpoint
596
  @app.get("/streaming-demo", include_in_schema=False)
597
  async def streaming_demo():
598
  """Streaming TTS demo endpoint."""
599
+ return FileResponse(os.path.join(APP_DIR, "static/streaming-demo.html"))
600
 
601
  @app.get("/", include_in_schema=False)
602
  async def root():