# user-agent's picture
# Update app.py
# 4c6f845 verified
# raw
# history blame contribute delete
# 1.61 kB
import spaces
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr
# Load the CLIP weights and the matching preprocessor once, at import time,
# so every request reuses the same objects. Both come from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
@spaces.GPU  # Use the GPU decorator for the function that requires GPU
def get_embedding(image_or_text):
    """Return the CLIP embedding of an image URL or of a piece of text.

    Args:
        image_or_text: Either an ``http://`` / ``https://`` URL pointing at an
            image, or arbitrary text. Anything that is not such a URL is
            embedded as text.

    Returns:
        The embedding flattened into a plain list of floats.

    Raises:
        requests.RequestException: If the image cannot be downloaded (network
            failure, timeout, or a non-2xx HTTP status).
    """
    # Resolve the device inside the function so the GPU is picked up when it
    # is available at call time rather than at import time.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Only used for URL detection and fetching; text input is embedded exactly
    # as the caller supplied it.
    candidate_url = image_or_text.strip()

    # Require the full "scheme://" prefix so ordinary text that merely starts
    # with "http:" is not mistaken for a URL; schemes are case-insensitive.
    if candidate_url.lower().startswith(("http://", "https://")):
        # NOTE(security): this fetches a caller-supplied URL from the server.
        # If the app is exposed publicly, consider restricting target hosts.
        # A timeout keeps a stalled remote host from blocking the worker forever.
        response = requests.get(candidate_url, timeout=30)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as fetched:
            # Normalise palette / alpha / grayscale images to 3-channel RGB.
            image = fetched.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs)
    else:
        # truncation=True clips over-long text to the tokenizer's maximum
        # length instead of letting the model fail on an oversized sequence.
        inputs = processor(
            text=[image_or_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            features = model.get_text_features(**inputs)

    # Move to CPU and flatten so the result is JSON-serialisable.
    return features.cpu().numpy().flatten().tolist()
# Gradio UI: a single text box in, the embedding rendered as JSON out.
interface = gr.Interface(
    fn=get_embedding,
    inputs="text",
    outputs="json",
    title="CLIP Model Embeddings",
    description="Enter an Image URL or text to get embeddings from CLIP.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script; share=True asks Gradio
    # to also create a public share link.
    interface.launch(share=True)