atharvasc27112001 committed
Commit 43d8873 · verified · 1 Parent(s): 3180216

Update app.py

Files changed (1)
  app.py +18 -17
app.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
+from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
 import gradio as gr
 import soundfile as sf
 
@@ -15,15 +15,16 @@ print("Loading Whisper model...")
 whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
 whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
 
-print("Loading GPT-2 model (placeholder for your text model)...")
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-text_model = AutoModelForCausalLM.from_pretrained("gpt2")
+print("Loading Flan-T5 model (instruction-tuned for better responses)...")
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
+text_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
 
 # ------------------------------
 # Define Projection Layers
 # ------------------------------
 print("Initializing image projection layer...")
-# Project CLIP's 512-dimensional image embeddings to GPT-2's 768-dimensional space.
+# This linear layer projects CLIP's 512-dimensional image embeddings to Flan-T5's expected dimension.
+# (For a real system, you would fine-tune this layer.)
 image_projection = torch.nn.Linear(512, 768)
 
 # ------------------------------
@@ -33,11 +34,11 @@ image_projection = torch.nn.Linear(512, 768)
 def multimodal_inference(text_input, image_input, audio_input):
     """
     Processes text, image, and audio inputs:
-      - Text: used directly.
-      - Image: processed via CLIP and projected (here, we append a placeholder tag).
-      - Audio: transcribed using Whisper.
+      - Text: is used directly.
+      - Image: is processed via CLIP; its embedding is projected and a placeholder is appended.
+      - Audio: is transcribed using Whisper.
 
-    The final prompt is fed to the text model (GPT-2) to generate a response.
+    The combined prompt is then fed into Flan-T5 to generate a text response.
     """
     prompt = ""
 
@@ -54,7 +55,7 @@ def multimodal_inference(text_input, image_input, audio_input):
             # Normalize and project image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            projected_image = image_projection(image_features)
-            # For demo purposes, we append a placeholder tag.
+            # For this demo, we append a placeholder tag to indicate image information.
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
@@ -79,16 +80,16 @@ def multimodal_inference(text_input, image_input, audio_input):
 
     print("Final fused prompt:", prompt)
 
-    # Generate text response using the text model with advanced decoding parameters
+    # Tokenize and generate text using Flan-T5
     inputs = tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
         generated_ids = text_model.generate(
             **inputs,
             max_length=200,
-            temperature=0.7,        # Controls randomness (0=deterministic, 1=more random)
-            top_p=0.9,              # Limits sampling to the top 90% probability mass
-            repetition_penalty=1.2, # Penalizes repeated phrases
-            do_sample=True          # Enables sampling (instead of greedy decoding)
+            temperature=0.7,        # Moderate randomness
+            top_p=0.9,              # Nucleus sampling to limit token choices
+            repetition_penalty=1.2, # Penalize repeated tokens
+            do_sample=True          # Enable sampling for more varied responses
         )
     generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
@@ -106,8 +107,8 @@ iface = gr.Interface(
         gr.Audio(type="filepath", label="Audio Input (Optional)")
     ],
     outputs="text",
-    title="Multi-Modal LLM Demo",
-    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response."
+    title="Multi-Modal LLM Demo with Flan-T5",
+    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response using an instruction-tuned model."
 )
 
 if __name__ == "__main__":
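
Note: the audio branch itself sits in lines this diff leaves unchanged, so it does not appear above. The sketch below is a hypothetical reconstruction of that step, assuming audio_input is a file path (as the gr.Audio(type="filepath") component provides) and the recording is already at Whisper's expected 16 kHz sampling rate; the actual app.py may differ.

# Hypothetical sketch of the unchanged audio branch (not part of this commit).
import soundfile as sf
import torch

def transcribe_audio(audio_path, whisper_processor, whisper_model):
    # Load the waveform from disk; Whisper's feature extractor expects 16 kHz audio.
    speech, sample_rate = sf.read(audio_path)
    features = whisper_processor(speech, sampling_rate=sample_rate,
                                 return_tensors="pt").input_features
    with torch.no_grad():
        predicted_ids = whisper_model.generate(features)  # default (greedy) decoding
    # Decode token ids back into a transcript string.
    return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]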
 
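A further note on the projection layer: the commit keeps torch.nn.Linear(512, 768), but google/flan-t5-large has d_model=1024 (768 is the hidden size of flan-t5-base), and the projected embedding is only signalled through the "[IMAGE_EMBEDDING]" placeholder rather than passed to the model. The sketch below, which is not part of this commit, shows one way the projected CLIP embedding could actually be fed to the seq2seq model through inputs_embeds with a recent transformers version; build_inputs_with_image is an illustrative helper name.

# Illustrative only: prepend the projected image embedding to the encoder inputs.
import torch

hidden_size = text_model.config.d_model               # 1024 for flan-t5-large
image_projection = torch.nn.Linear(512, hidden_size)  # CLIP image features are 512-d

def build_inputs_with_image(prompt, projected_image):
    # Embed the text prompt with the model's own input embedding table.
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids
    token_embeds = text_model.get_input_embeddings()(token_ids)   # (1, seq_len, d_model)
    # Treat the projected image vector as one extra leading "token".
    image_token = projected_image.unsqueeze(1)                    # (1, 1, d_model)
    inputs_embeds = torch.cat([image_token, token_embeds], dim=1)
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)
    return inputs_embeds, attention_mask

# Usage, replacing the tokenizer(prompt, ...) call inside multimodal_inference:
# inputs_embeds, attention_mask = build_inputs_with_image(prompt, projected_image)
# generated_ids = text_model.generate(inputs_embeds=inputs_embeds,
#                                     attention_mask=attention_mask, max_length=200)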