preston-cell committed on
Commit
f4f3543
·
verified ·
1 Parent(s): a483c36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import (
5
  AutoModelForCausalLM,
6
  AutoTokenizer,
7
  GenerationConfig,
8
- TextStreamer,
9
  set_seed
10
  )
11
  from datasets import load_dataset
@@ -44,16 +43,22 @@ doge_generation_config = GenerationConfig(
44
  repetition_penalty=1.0
45
  )
46
 
47
- # Load speaker embedding with exactly 600 values
48
  speaker_embedding = None
49
  embedding_data = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
 
50
  for entry in embedding_data:
51
  vec = entry["xvector"]
52
  if len(vec) >= 600:
53
- speaker_embedding = torch.tensor(vec[:600], dtype=torch.float32).unsqueeze(0) # Shape: [1, 600]
54
  break
 
 
55
  if speaker_embedding is None:
56
- raise ValueError("No suitable speaker embedding of at least 600 dimensions found.")
 
 
 
57
  assert speaker_embedding.shape == (1, 600), f"Expected shape (1, 600), got {speaker_embedding.shape}"
58
 
59
 
@@ -75,7 +80,7 @@ def process_image(image):
75
 
76
  # 3. Prompt Doge model for context generation
77
  prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
78
- prompt = prompt[:600] # Ensure prompt isn't too long
79
  conversation = [{"role": "user", "content": prompt}]
80
  doge_inputs = doge_tokenizer.apply_chat_template(
81
  conversation=conversation,
 
5
  AutoModelForCausalLM,
6
  AutoTokenizer,
7
  GenerationConfig,
 
8
  set_seed
9
  )
10
  from datasets import load_dataset
 
43
  repetition_penalty=1.0
44
  )
45
 
46
+ # Load speaker embedding with fallback
47
  speaker_embedding = None
48
  embedding_data = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
49
+
50
  for entry in embedding_data:
51
  vec = entry["xvector"]
52
  if len(vec) >= 600:
53
+ speaker_embedding = torch.tensor(vec[:600], dtype=torch.float32).unsqueeze(0)
54
  break
55
+
56
+ # Fallback: use a zero vector if none found
57
  if speaker_embedding is None:
58
+ print("⚠️ No suitable speaker embedding found. Using default 600-dim zero vector.")
59
+ speaker_embedding = torch.zeros(1, 600, dtype=torch.float32)
60
+
61
+ # Ensure correct shape
62
  assert speaker_embedding.shape == (1, 600), f"Expected shape (1, 600), got {speaker_embedding.shape}"
63
 
64
 
 
80
 
81
  # 3. Prompt Doge model for context generation
82
  prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
83
+ prompt = prompt[:600] # Prevent tensor mismatch error
84
  conversation = [{"role": "user", "content": prompt}]
85
  doge_inputs = doge_tokenizer.apply_chat_template(
86
  conversation=conversation,