RishabA committed
Commit 2978074 · verified · 1 Parent(s): 42296ae

Update app.py

Files changed (1)
  1. app.py +24 -28
app.py CHANGED
@@ -13,7 +13,6 @@ n_layers = 6
 n_heads = 8
 
 tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
-
 transform = transforms.Compose(
     [
         transforms.Resize(image_size),
@@ -23,10 +22,9 @@ transform = transforms.Compose(
     ]
 )
 
-# Instantiate your model
 model = CaptioningTransformer(
     image_size=image_size,
-    in_channels=3,  # RGB images
+    in_channels=3,
     vocab_size=tokenizer.vocab_size,
     device=device,
     patch_size=patch_size,
@@ -35,53 +33,51 @@ model = CaptioningTransformer(
     n_heads=n_heads,
 ).to(device)
 
-# Load your pre-trained weights (make sure the .pt file is in your repo)
 model_path = "image_captioning_model.pt"
 model.load_state_dict(torch.load(model_path, map_location=device))
 model.eval()
 
 
-# This is your existing inference function (you can modify as needed)
-def make_prediction(model, sos_token, eos_token, image, max_len, temp, device):
-    log_tokens = [sos_token]  # Start with the start-of-sequence token
+def make_prediction(
+    model, sos_token, eos_token, image, max_len=50, temp=0.5, device=device
+):
+    log_tokens = [sos_token]
     with torch.inference_mode():
-        # Get image embeddings from the encoder
         image_embedding = model.encoder(image.to(device))
         for _ in range(max_len):
            input_tokens = torch.cat(log_tokens, dim=1)
            data_pred = model.decoder(input_tokens.to(device), image_embedding)
-            # Get the logits for the most recent token only
            dist = torch.distributions.Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)
            log_tokens.append(next_tokens.cpu())
-            if next_tokens.item() == 102:  # Assuming 102 is your [SEP] token
+            if next_tokens.item() == 102:
                break
        return torch.cat(log_tokens, dim=1)
 
 
-# Define the Gradio prediction function
 def predict(image: Image.Image):
-    # Preprocess the image
-    img_tensor = transform(image).unsqueeze(0)  # Shape: (1, 3, image_size, image_size)
-    # Create a start-of-sequence token (assuming 101 is your [CLS] token)
+    img_tensor = transform(image).unsqueeze(0)
     sos_token = 101 * torch.ones(1, 1).long().to(device)
-    # Generate caption tokens using your inference function
-    tokens = make_prediction(
-        model, sos_token, 102, img_tensor, max_len=50, temp=0.5, device=device
-    )
-    # Decode tokens to text (skipping special tokens)
+    tokens = make_prediction(model, sos_token, 102, img_tensor)
     caption = tokenizer.decode(tokens[0], skip_special_tokens=True)
     return caption
 
 
-# Create a Gradio interface
-iface = gr.Interface(
-    fn=predict,
-    inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Image Captioning Model",
-    description="Upload an image and get a caption generated by the model.",
-)
+with gr.Blocks(css=".block-title { font-size: 24px; font-weight: bold; }") as demo:
+    gr.Markdown("<div class='block-title'>Image Captioning with PyTorch</div>")
+    gr.Markdown("Upload an image and get a descriptive caption about the image:")
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil", label="Your Image")
+            generate_button = gr.Button("Generate Caption")
+        with gr.Column():
+            caption_output = gr.Textbox(
+                label="Caption Output",
+                placeholder="Your generated caption will appear here...",
+            )
+
+    generate_button.click(fn=predict, inputs=image_input, outputs=caption_output)
 
 if __name__ == "__main__":
-    iface.launch()
+    demo.launch(share=True)
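
Note on the constants in the diff: 101 and 102 are the [CLS] and [SEP] token IDs of the distilbert-base-uncased vocabulary, and the new make_prediction signature bakes temperature-0.5 sampling into the decode loop as the default. A minimal sketch of both details, runnable without the model checkpoint (the toy logits below are invented for illustration and are not produced by the model):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# The hard-coded 101 / 102 in app.py are these special-token IDs:
print(tokenizer.cls_token_id, tokenizer.sep_token_id)  # -> 101 102

# Temperature-scaled sampling as in make_prediction, on toy logits.
torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # hypothetical (batch=1, vocab=4) scores
for temp in (0.5, 1.0, 2.0):
    dist = torch.distributions.Categorical(logits=logits / temp)
    probs = [round(p, 3) for p in dist.probs.squeeze().tolist()]
    # Dividing logits by temp < 1 sharpens the distribution (more greedy);
    # temp > 1 flattens it (more diverse samples).
    print(f"temp={temp}: probs={probs}, sample={dist.sample().item()}")

With temp=0.5 most of the probability mass lands on the top-scoring token, which is why the default produces more conservative captions than sampling at temperature 1.0.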