Bils committed on
Commit
df2a904
·
verified ·
1 Parent(s): aefc26a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -40,25 +40,20 @@ MODEL_CONFIG = {
40
  # -------------------------------
41
  # Model Manager with Cache
42
  # -------------------------------
 
43
  class ModelManager:
44
- def __init__(self):
45
- self.llama_pipelines = {}
46
- self.musicgen_model = None
47
- self.tts_models = {}
48
-
49
  def get_llama_pipeline(self, model_id, token):
50
  if model_id not in self.llama_pipelines:
51
  tokenizer = AutoTokenizer.from_pretrained(
52
  model_id,
53
- use_auth_token=token,
54
- legacy=False # Important for compatibility
55
  )
56
  model = AutoModelForCausalLM.from_pretrained(
57
  model_id,
58
- use_auth_token=token,
59
  torch_dtype=torch.float16,
60
- device_map="auto",
61
- low_cpu_mem_usage=True # Reduces memory pressure
62
  )
63
  self.llama_pipelines[model_id] = pipeline(
64
  "text-generation",
@@ -67,7 +62,7 @@ class ModelManager:
67
  device_map="auto"
68
  )
69
  return self.llama_pipelines[model_id]
70
-
71
  def get_musicgen_model(self):
72
  if not self.musicgen_model:
73
  self.musicgen_model = MusicgenForConditionalGeneration.from_pretrained(
 
40
  # -------------------------------
41
  # Model Manager with Cache
42
  # -------------------------------
43
+
44
  class ModelManager:
 
 
 
 
 
45
  def get_llama_pipeline(self, model_id, token):
46
  if model_id not in self.llama_pipelines:
47
  tokenizer = AutoTokenizer.from_pretrained(
48
  model_id,
49
+ token=token,
50
+ legacy=False # Critical for tokenizers 0.19.x compatibility
51
  )
52
  model = AutoModelForCausalLM.from_pretrained(
53
  model_id,
54
+ token=token,
55
  torch_dtype=torch.float16,
56
+ device_map="auto"
 
57
  )
58
  self.llama_pipelines[model_id] = pipeline(
59
  "text-generation",
 
62
  device_map="auto"
63
  )
64
  return self.llama_pipelines[model_id]
65
+
66
  def get_musicgen_model(self):
67
  if not self.musicgen_model:
68
  self.musicgen_model = MusicgenForConditionalGeneration.from_pretrained(