aryan083 commited on
Commit
62d1efe
·
verified ·
1 Parent(s): 50e29d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -1,45 +1,45 @@
1
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
2
- import torch
3
- from PIL import Image
4
- import gradio as gr
5
-
6
- model_name = "aryan083/vit-gpt2-image-captioning"
7
- model = VisionEncoderDecoderModel.from_pretrained(model_name)
8
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
-
11
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
- model.to(device)
13
-
14
- def predict_caption(image):
15
- if image is None:
16
- return None
17
-
18
- images = []
19
- images.append(image)
20
-
21
- pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
22
- pixel_values = pixel_values.to(device)
23
-
24
- output_ids = model.generate(
25
- pixel_values,
26
- do_sample=True,
27
- max_length=16,
28
- num_beams=4,
29
- temperature=0.7
30
- )
31
-
32
- preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
33
- return preds[0].strip()
34
-
35
- # Create Gradio interface
36
- iface = gr.Interface(
37
- fn=predict_caption,
38
- inputs=gr.Image(type="pil"),
39
- outputs=gr.Textbox(label="Generated Caption"),
40
- title="Image Captioning",
41
- description="Upload an image and get its description generated using ViT-GPT2",
42
- # examples=[["assets/example1.jpg"]] # Add example images if you have any
43
- )
44
-
45
  iface.launch()
 
1
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
2
+ import torch
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ model_name = "aryan083/vit-gpt2-image-captioning"
7
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
8
+ feature_extractor = ViTImageProcessor.from_pretrained(model_name) # Changed from ViTFeatureExtractor to ViTImageProcessor
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ model.to(device)
13
+
14
+ def predict_caption(image):
15
+ if image is None:
16
+ return None
17
+
18
+ images = []
19
+ images.append(image)
20
+
21
+ pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
22
+ pixel_values = pixel_values.to(device)
23
+
24
+ output_ids = model.generate(
25
+ pixel_values,
26
+ do_sample=True,
27
+ max_length=16,
28
+ num_beams=4,
29
+ temperature=0.7
30
+ )
31
+
32
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
33
+ return preds[0].strip()
34
+
35
+ # Create Gradio interface
36
+ iface = gr.Interface(
37
+ fn=predict_caption,
38
+ inputs=gr.Image(type="pil"),
39
+ outputs=gr.Textbox(label="Generated Caption"),
40
+ title="Image Captioning",
41
+ description="Upload an image and get its description generated using ViT-GPT2",
42
+ # examples=[["assets/example1.jpg"]] # Add example images if you have any
43
+ )
44
+
45
  iface.launch()