Bils committed
Commit 5607a62 · verified · Parent: df2a904

Update app.py

Files changed (1): app.py (+12, -9)
app.py CHANGED
@@ -42,18 +42,25 @@ MODEL_CONFIG = {
 # -------------------------------
 
 class ModelManager:
+    def __init__(self):
+        self.llama_pipelines = {}
+        self.musicgen_model = None
+        self.tts_models = {}
+        self.processor = None  # Add processor cache
+
     def get_llama_pipeline(self, model_id, token):
         if model_id not in self.llama_pipelines:
             tokenizer = AutoTokenizer.from_pretrained(
                 model_id,
                 token=token,
-                legacy=False  # Critical for tokenizers 0.19.x compatibility
+                legacy=False
             )
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 token=token,
                 torch_dtype=torch.float16,
-                device_map="auto"
+                device_map="auto",
+                low_cpu_mem_usage=True
             )
             self.llama_pipelines[model_id] = pipeline(
                 "text-generation",
@@ -62,19 +69,15 @@ class ModelManager:
                 device_map="auto"
             )
         return self.llama_pipelines[model_id]
-
+
     def get_musicgen_model(self):
         if not self.musicgen_model:
             self.musicgen_model = MusicgenForConditionalGeneration.from_pretrained(
                 MODEL_CONFIG["musicgen_model"]
             )
+            self.processor = AutoProcessor.from_pretrained(MODEL_CONFIG["musicgen_model"])
             self.musicgen_model.to("cuda" if torch.cuda.is_available() else "cpu")
-        return self.musicgen_model
-
-    def get_tts_model(self, model_name):
-        if model_name not in self.tts_models:
-            self.tts_models[model_name] = TTS(model_name)
-        return self.tts_models[model_name]
+        return self.musicgen_model, self.processor
 
 model_manager = ModelManager()
 
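The practical effect of this commit is an API change to ModelManager: get_musicgen_model() now returns a (model, processor) tuple instead of the bare model, and get_tts_model() is removed, so call sites elsewhere in app.py must be updated to match. Below is a minimal caller sketch under those assumptions; the model id, token source, prompt, and generation parameters are illustrative and not part of the diff.

import os

import torch

# get_musicgen_model() now returns (model, processor), so unpack both.
musicgen_model, processor = model_manager.get_musicgen_model()

device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = processor(
    text=["lo-fi beat with mellow piano"],  # illustrative prompt
    padding=True,
    return_tensors="pt",
).to(device)
with torch.no_grad():
    audio_values = musicgen_model.generate(**inputs, max_new_tokens=256)

# The Llama pipeline is cached per model_id and reused on later calls.
pipe = model_manager.get_llama_pipeline(
    "meta-llama/Llama-2-7b-chat-hf",   # hypothetical model id
    token=os.environ.get("HF_TOKEN"),  # hypothetical token source
)
result = pipe("Write a one-line radio jingle.", max_new_tokens=64)
print(result[0]["generated_text"])

# get_tts_model() was removed in this commit: any remaining caller of
# model_manager.get_tts_model(...) will now raise AttributeError.

Caching the processor on the manager avoids reloading it per request, but the changed return signature of get_musicgen_model() is a breaking change for existing callers.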