BarBar288 committed
Commit c5a9717 · verified · 1 Parent(s): 11fc91a

Update app.py

Files changed (1):
  1. app.py +21 -15
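In substance, the commit is a device-placement pass over app.py: it selects a torch device once, hands it to each transformers pipeline via the device argument, and moves directly loaded models and their input tensors with .to(device). Below is a minimal, self-contained sketch of that pattern (not the app's code); it assumes torch and transformers are installed and uses gpt2 purely as a small stand-in model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Select the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# transformers pipelines accept a `device` argument (int, str, or torch.device).
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)

# Models loaded outside a pipeline need an explicit .to(device),
# and so do the tensors fed into them.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
input_ids = tokenizer.encode("Hello, how are you?", return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))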
app.py CHANGED
@@ -16,6 +16,10 @@ if not read_token:
 from huggingface_hub import login
 login(read_token)
 
+# Set device to GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger.info(f"Device set to use {device}")
+
 # Define a dictionary of conversational models
 conversational_models = {
     "Qwen": "Qwen/QwQ-32B",
@@ -49,12 +53,12 @@ text_to_image_pipelines = {}
 text_to_speech_pipelines = {}
 
 # Initialize pipelines for other tasks
-visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
-document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
-image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224")
-object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")
-video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400")
-summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
+visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa", device=device)
+document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2", device=device)
+image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224", device=device)
+object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50", device=device)
+video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400", device=device)
+summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
 
 # Load speaker embeddings for text-to-audio
 def load_speaker_embeddings(model_name):
@@ -62,26 +66,26 @@ def load_speaker_embeddings(model_name):
         logger.info("Loading speaker embeddings for SpeechT5")
         from datasets import load_dataset
         dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0) # Example speaker
+        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0).to(device) # Example speaker
         return speaker_embeddings
     return None
 
 # Use a different model for text-to-audio if stabilityai/stable-audio-open-1.0 is not supported
 try:
-    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0", use_auth_token=read_token)
+    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0", use_auth_token=read_token, device=device)
 except ValueError as e:
     logger.error(f"Error loading stabilityai/stable-audio-open-1.0: {e}")
     logger.info("Falling back to a different text-to-audio model.")
-    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts")
+    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts", use_auth_token=read_token, device=device)
     speaker_embeddings = load_speaker_embeddings("microsoft/speecht5_tts")
 
-audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base")
+audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base", device=device)
 
 def load_conversational_model(model_name):
     if model_name not in conversational_models_loaded:
         logger.info(f"Loading conversational model: {model_name}")
         tokenizer = AutoTokenizer.from_pretrained(conversational_models[model_name], use_auth_token=read_token)
-        model = AutoModelForCausalLM.from_pretrained(conversational_models[model_name], use_auth_token=read_token)
+        model = AutoModelForCausalLM.from_pretrained(conversational_models[model_name], use_auth_token=read_token).to(device)
         conversational_tokenizers[model_name] = tokenizer
         conversational_models_loaded[model_name] = model
     return conversational_tokenizers[model_name], conversational_models_loaded[model_name]
@@ -90,7 +94,7 @@ def chat(model_name, user_input, history=[]):
     tokenizer, model = load_conversational_model(model_name)
 
     # Encode the input
-    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
+    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
 
     # Generate a response
     with torch.no_grad():
@@ -110,7 +114,7 @@ def generate_image(model_name, prompt):
     if model_name not in text_to_image_pipelines:
         logger.info(f"Loading text-to-image model: {model_name}")
         text_to_image_pipelines[model_name] = StableDiffusionPipeline.from_pretrained(
-            text_to_image_models[model_name], use_auth_token=read_token
+            text_to_image_models[model_name], use_auth_token=read_token, torch_dtype=torch.float16, device_map="auto"
         )
     pipeline = text_to_image_pipelines[model_name]
     image = pipeline(prompt).images[0]
@@ -120,12 +124,14 @@ def generate_speech(model_name, text):
     if model_name not in text_to_speech_pipelines:
         logger.info(f"Loading text-to-speech model: {model_name}")
         text_to_speech_pipelines[model_name] = pipeline(
-            "text-to-speech", model=text_to_speech_models[model_name], use_auth_token=read_token
+            "text-to-speech", model=text_to_speech_models[model_name], use_auth_token=read_token, device=device
         )
     pipeline = text_to_speech_pipelines[model_name]
-    audio = pipeline(text)
+    audio = pipeline(text, speaker_embeddings=speaker_embeddings)
     return audio["audio"]
 
+
+
 def visual_qa(image, question):
     result = visual_qa_pipeline(image, question)
     return result["answer"]
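A caveat on the text-to-speech change: the microsoft/speecht5_tts model card routes speaker embeddings through the pipeline's forward_params rather than passing them as a direct keyword, so the new pipeline(text, speaker_embeddings=...) call may need adjusting depending on the transformers version. A sketch of the documented flow, reusing the same dataset and speaker index as app.py:

import torch
from datasets import load_dataset
from transformers import pipeline

tts = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# x-vector speaker embeddings; same dataset and example index as in app.py.
dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)

# The documented usage passes the embeddings via forward_params.
speech = tts("Hello from SpeechT5.", forward_params={"speaker_embeddings": speaker_embeddings})
audio, sampling_rate = speech["audio"], speech["sampling_rate"]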