import random

import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from peft import PeftModel
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    SiglipVisionModel,
)

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Half precision only pays off on GPU; on CPU many float16 kernels are missing
# or very slow, so fall back to float32 there.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32


def _load_trusted_module(path):
    """Unpickle a saved projection object from ``path`` onto ``device``.

    The object is called directly later in this script, so it is presumably a
    whole module saved with ``torch.save(module)`` rather than a state dict.
    ``weights_only=False`` is therefore required: since PyTorch 2.6
    ``torch.load`` defaults to ``weights_only=True`` and rejects such pickles.
    SECURITY: unpickling can execute arbitrary code -- only use on local,
    trusted files.
    """
    obj = torch.load(path, map_location=device, weights_only=False)
    if isinstance(obj, torch.nn.Module):
        # Match the backbone dtype (a float16 activation fed into a float32
        # Linear raises a dtype-mismatch error) and disable dropout etc.
        obj = obj.to(device=device, dtype=model_dtype)
        obj.eval()
    return obj


def load_models():
    """Load every model component the app needs.

    Returns:
        Tuple ``(siglip_model, siglip_processor, phi_model, phi_tokenizer,
        linear_proj, image_text_proj)``, all placed on ``device``.
    """
    siglip_checkpoint = "google/siglip-so400m-patch14-384"
    phi_checkpoint = "microsoft/phi-2"

    print("Loading SigLIP model...")
    siglip_model = SiglipVisionModel.from_pretrained(
        siglip_checkpoint,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    siglip_processor = AutoImageProcessor.from_pretrained(siglip_checkpoint)

    print("Loading Phi-2 model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        phi_checkpoint,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    # Apply the trained LoRA adapter on top of the base language model.
    print("Loading trained LoRA weights...")
    phi_model = PeftModel.from_pretrained(base_model, "phi_model_trained")
    phi_model.eval()

    phi_tokenizer = AutoTokenizer.from_pretrained(phi_checkpoint)
    # Phi-2 ships without a pad token; reuse EOS so padding=True works.
    if phi_tokenizer.pad_token is None:
        phi_tokenizer.pad_token = phi_tokenizer.eos_token

    print("Loading projection layers...")
    linear_proj = _load_trusted_module("linear_projection_final.pth")
    image_text_proj = _load_trusted_module("image_text_proj.pth")

    return (
        siglip_model,
        siglip_processor,
        phi_model,
        phi_tokenizer,
        linear_proj,
        image_text_proj,
    )


# Load all models at startup
print("Loading models...")
models = load_models()
(
    siglip_model,
    siglip_processor,
    phi_model,
    phi_tokenizer,
    linear_proj,
    image_text_proj,
) = models
print("Models loaded successfully!")

# CIFAR-10 test split, upsampled to SigLIP's 384x384 input size. ToTensor()
# yields (C, H, W) float tensors in [0, 1].
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])
testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
# Cache the first 100 test images as (image_tensor, label) pairs. Index the
# dataset directly: ``list(testset)`` would decode and upsample all 10,000
# test images (~17 GB of 384x384 float tensors) just to keep 100 of them.
first_100_images = [testset[i] for i in range(min(100, len(testset)))]

# Questions list
questions = [
    "Give a description of the image?",
    "How does the main object in the image look like?",
    "How can the main object in the image be useful to humans?",
    "What is the color of the main object in the image?",
    "Describe the setting of the image?",
]


def _to_model_input(value, device, dtype):
    """Move tensors to ``device``; floating tensors are also cast to ``dtype``."""
    if not isinstance(value, torch.Tensor):
        return value
    if value.is_floating_point():
        return value.to(device=device, dtype=dtype)
    return value.to(device)


def _to_display_array(image_tensor):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array."""
    array = image_tensor.permute(1, 2, 0).numpy()
    return (array * 255).round().clip(0, 255).astype(np.uint8)


def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
    """Encode one image with SigLIP and map its pooled output through ``linear_proj``.

    Args:
        image: Image accepted by ``siglip_processor``; in this app a tensor
            taken from ``first_100_images``.
        siglip_model: SigLIP vision tower.
        siglip_processor: Matching image processor.
        linear_proj: Trained projection applied to SigLIP's ``pooler_output``.
        device: Device the model lives on.

    Returns:
        The projected features, computed without an autograd graph.
    """
    with torch.no_grad():
        # NOTE(review): ``image`` is already a float tensor in [0, 1], and the
        # SigLIP processor rescales by 1/255 again by default. Presumably
        # training used this same path, so it is left unchanged -- confirm
        # before switching to ``do_rescale=False``.
        raw_inputs = siglip_processor(image, return_tensors="pt")
        # Cast pixel values to the model's dtype: the processor emits float32
        # while the SigLIP weights may be float16.
        inputs = {
            key: _to_model_input(value, device, siglip_model.dtype)
            for key, value in raw_inputs.items()
        }
        image_features = siglip_model(**inputs).pooler_output
        return linear_proj(image_features)


def get_random_images(num_images=10):
    """Pick ``num_images`` distinct cached images at random for the gallery.

    Returns:
        ``(images, indices)``: ``images`` are (H, W, C) uint8 arrays for
        display, and ``indices`` are their positions in ``first_100_images``.
    """
    num_images = min(num_images, len(first_100_images))
    selected_indices = random.sample(range(len(first_100_images)), num_images)
    images_np = [_to_display_array(first_100_images[i][0]) for i in selected_indices]
    return images_np, selected_indices


def generate_answer(image_tensor, question_index):
    """Answer ``questions[question_index]`` about ``image_tensor``.

    The projected image embedding is prepended to the embedded question and
    the LoRA-adapted language model generates from the combined sequence.

    Args:
        image_tensor: Image tensor from ``first_100_images``, or None.
        question_index: Index into ``questions``. Coerced to ``int`` because
            Gradio numeric components can deliver floats such as ``0.0``.

    Returns:
        The decoded answer, or a message string when no image is selected or
        generation fails (errors are shown in the UI, not raised).
    """
    if image_tensor is None:
        return "Please select an image first!"

    try:
        question = questions[int(question_index)]

        # No step here needs gradients; skipping autograd saves memory.
        with torch.no_grad():
            image_embedding = get_image_embedding(
                image_tensor, siglip_model, siglip_processor, linear_proj, device
            )

            question_tokens = phi_tokenizer(
                question,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(device)
            question_embeds = phi_model.get_input_embeddings()(
                question_tokens["input_ids"]
            )

            # Assumes image_text_proj returns (batch, hidden) -- TODO confirm.
            # unsqueeze(1) turns it into a length-1 "image token" sequence.
            image_embeds = image_text_proj(image_embedding).unsqueeze(1)
            # Align with the LM's embedding dtype so torch.cat does not upcast
            # and generate() receives inputs matching its weights.
            image_embeds = image_embeds.to(question_embeds.dtype)

            combined_embedding = torch.cat([image_embeds, question_embeds], dim=1)
            attention_mask = torch.ones(
                combined_embedding.shape[:2], dtype=torch.long, device=device
            )

            outputs = phi_model.generate(
                inputs_embeds=combined_embedding,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=4,
                temperature=0.7,
                do_sample=True,
                pad_token_id=phi_tokenizer.pad_token_id,
                eos_token_id=phi_tokenizer.eos_token_id,
            )

        return phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure in the answer box instead of crashing.
        return f"Error generating answer: {str(e)}"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CIFAR10 Image Question Answering System")

    # Per-session state: the tensor of the selected image, and the cached
    # indices currently shown in the gallery.
    selected_image_tensor = gr.State(None)
    image_indices = gr.State([])

    with gr.Row():
        with gr.Column():
            random_btn = gr.Button("Get Random Images")
            gallery = gr.Gallery(
                label="Click an image to select it",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                height="auto",
                allow_preview=False,
            )
        with gr.Column():
            selected_img = gr.Image(label="Selected Image", height=200)
            q_buttons = [
                gr.Button(f"Q{i+1}: {q}") for i, q in enumerate(questions)
            ]
            answer_box = gr.Textbox(label="Answer", lines=3)

    def on_random_click():
        """Show a fresh random sample and clear any previous selection/answer."""
        images, indices = get_random_images()
        return {
            gallery: images,
            image_indices: indices,
            selected_image_tensor: None,
            selected_img: None,
            answer_box: "",
        }

    random_btn.click(
        on_random_click,
        outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box],
    )

    def on_image_select(evt: gr.SelectData, indices):
        """Select the clicked gallery image via the stored ``indices``.

        The image is rebuilt from ``first_100_images`` rather than read back
        from the gallery, whose preprocessed value in Gradio 4 is a list of
        (image, caption) tuples that ``gr.Image`` cannot display directly.
        """
        position = evt.index
        if not indices or not 0 <= position < len(indices):
            return None, None, ""
        selected_tensor = first_100_images[indices[position]][0]
        return selected_tensor, _to_display_array(selected_tensor), ""

    gallery.select(
        on_image_select,
        inputs=[image_indices],
        outputs=[selected_image_tensor, selected_img, answer_box],
    )

    def _make_answer_handler(question_index):
        """Return a click handler with ``question_index`` bound at creation time."""

        def answer(image_tensor):
            return generate_answer(image_tensor, question_index)

        return answer

    # Bind each question's index in Python instead of via a hidden gr.Number,
    # which passes the value to the handler as a float.
    for i, btn in enumerate(q_buttons):
        btn.click(
            _make_answer_handler(i),
            inputs=[selected_image_tensor],
            outputs=answer_box,
        )

# Queue holds at most one waiting request; extra requests are rejected with a
# "queue full" message. The guard keeps module imports (e.g. Gradio's reload
# mode) from starting a second server.
if __name__ == "__main__":
    demo.queue(max_size=1).launch(show_error=True)