import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoTokenizer
from model import CaptioningTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters (must match the values used during training,
# since the checkpoint's state dict has to fit this architecture)
image_size = 128
patch_size = 8
d_model = 192
n_layers = 6
n_heads = 8

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
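
# distilbert-base-uncased uses [CLS] (id 101) as the start-of-sequence token
# and [SEP] (id 102) as the end-of-sequence token; reading the ids from the
# tokenizer avoids hard-coding magic numbers below.
SOS_TOKEN_ID = tokenizer.cls_token_id  # 101
EOS_TOKEN_ID = tokenizer.sep_token_id  # 102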

transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Instantiate your model
model = CaptioningTransformer(
    image_size=image_size,
    in_channels=3,  # RGB images
    vocab_size=tokenizer.vocab_size,
    device=device,
    patch_size=patch_size,
    n_layers=n_layers,
    d_model=d_model,
    n_heads=n_heads,
).to(device)

# Load your pre-trained weights (make sure the .pt file is in your repo)
model_path = "image_captioning_model.pt"
model.load_state_dict(torch.load(model_path, map_location=device))
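# eval() disables dropout (and, if present, switches normalization layers to
# their running statistics), so inference is deterministic apart from the
# sampling step in make_prediction below.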
model.eval()

# Autoregressive inference: sample one token at a time until the
# end-of-sequence token is produced or max_len tokens have been generated.
def make_prediction(model, sos_token, eos_token, image, max_len, temp, device):
    log_tokens = [sos_token]  # Start with the start-of-sequence token
    with torch.inference_mode():
        # Encode the image once; the embedding is reused at every decoding step
        image_embedding = model.encoder(image.to(device))
        for _ in range(max_len):
            input_tokens = torch.cat(log_tokens, dim=1)
            data_pred = model.decoder(input_tokens.to(device), image_embedding)
            # Sample the next token from the temperature-scaled logits of the
            # most recent position only
            dist = torch.distributions.Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)
            log_tokens.append(next_tokens.cpu())
            if next_tokens.item() == eos_token:  # Stop at the [SEP] token
                break
    return torch.cat(log_tokens, dim=1)
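
# Optional smoke test: exercise the generation loop on a random image tensor.
# Left commented out so it does not run when the Space starts.
# dummy = torch.randn(1, 3, image_size, image_size)
# sos = torch.full((1, 1), SOS_TOKEN_ID, dtype=torch.long)
# tokens = make_prediction(model, sos, EOS_TOKEN_ID, dummy, max_len=20, temp=0.5, device=device)
# print(tokenizer.decode(tokens[0], skip_special_tokens=True))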

# Gradio prediction function: preprocess the image, generate token ids, and
# decode them into a caption string.
def predict(image: Image.Image):
    # Preprocess the image to shape (1, 3, image_size, image_size)
    img_tensor = transform(image).unsqueeze(0)
    # Start-of-sequence token ([CLS]) as a (1, 1) tensor
    sos_token = torch.full((1, 1), SOS_TOKEN_ID, dtype=torch.long, device=device)
    # Generate caption tokens with the inference function above
    tokens = make_prediction(
        model, sos_token, EOS_TOKEN_ID, img_tensor, max_len=50, temp=0.5, device=device
    )
    # Decode token ids to text, skipping special tokens such as [CLS] and [SEP]
    caption = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return caption
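
# Example usage outside Gradio ("example.jpg" is a hypothetical local file):
# print(predict(Image.open("example.jpg").convert("RGB")))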
# Create a Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Captioning Model",
    description="Upload an image and get a caption generated by the model.",
)

if __name__ == "__main__":
    iface.launch()