import torch
from transformers import (
    CLIPProcessor,
    CLIPModel,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
import gradio as gr
import soundfile as sf

# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading Flan-T5 model (instruction-tuned for better responses)...")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
text_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")

# ------------------------------
# Define Projection Layers
# ------------------------------
print("Initializing image projection layer...")
# This linear layer projects CLIP's 512-dimensional image embeddings to
# Flan-T5-large's hidden size (1024). (For a real system, you would fine-tune this layer.)
image_projection = torch.nn.Linear(512, 1024)

# ------------------------------
# Multi-Modal Inference Function
# ------------------------------
def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes text, image, and audio inputs:
      - Text: used directly.
      - Image: processed via CLIP; its embedding is projected and a placeholder is appended.
      - Audio: transcribed using Whisper.
    The combined prompt is then fed into Flan-T5 to generate a text response.
    """
    prompt = ""

    # Process text input
    if text_input:
        prompt += text_input.strip()

    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize and project image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            projected_image = image_projection(image_features)
            # For this demo, we append a placeholder tag to indicate image information;
            # the projected embedding itself is not injected into the language model.
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"

    # Process audio input if provided
    if audio_input is not None:
        try:
            audio, sr = sf.read(audio_input)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            # Whisper's feature extractor expects 16 kHz mono audio; down-mix stereo here.
            # If the recording uses a different sample rate, resample it before this step.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"

    print("Final fused prompt:", prompt)

    # Tokenize and generate text using Flan-T5
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        generated_ids = text_model.generate(
            **inputs,
            max_length=200,
            temperature=0.7,        # Moderate randomness
            top_p=0.9,              # Nucleus sampling to limit token choices
            repetition_penalty=1.2, # Penalize repeated tokens
            do_sample=True,         # Enable sampling for more varied responses
        )
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------
iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(type="filepath", label="Audio Input (Optional)"),
    ],
    outputs="text",
    title="Multi-Modal LLM Demo with Flan-T5",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response using an instruction-tuned model.",
)

if __name__ == "__main__":
    iface.launch()
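# ------------------------------
# Appendix: fusing the image embedding directly (illustrative sketch, not used by the demo)
# ------------------------------
# The demo above only appends an "[IMAGE_EMBEDDING]" placeholder to the prompt. A minimal
# sketch of real fusion would prepend the projected CLIP vector to Flan-T5's encoder input
# embeddings, as below. The helper name `fuse_image_and_text` is hypothetical, the code
# assumes a transformers version whose generate() accepts `inputs_embeds` for
# encoder-decoder models, and without fine-tuning the projection layer the outputs will
# not be meaningful.
def fuse_image_and_text(projected_image, prompt_text, max_new_tokens=100):
    """Prepend a projected image embedding (1, 1024) to the prompt's token embeddings."""
    enc = tokenizer(prompt_text, return_tensors="pt")
    # Token embeddings for the text prompt: (1, seq_len, d_model)
    token_embeds = text_model.get_input_embeddings()(enc.input_ids)
    # Treat the projected image vector as one extra "token": (1, 1, d_model)
    image_token = projected_image.unsqueeze(1)
    fused_embeds = torch.cat([image_token, token_embeds], dim=1)
    # Extend the attention mask to cover the prepended image token.
    image_mask = torch.ones((1, 1), dtype=enc.attention_mask.dtype)
    fused_mask = torch.cat([image_mask, enc.attention_mask], dim=1)
    with torch.no_grad():
        out_ids = text_model.generate(
            inputs_embeds=fused_embeds,
            attention_mask=fused_mask,
            max_new_tokens=max_new_tokens,
        )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)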