krishnv commited on
Commit
9633d94
·
verified ·
1 Parent(s): 31e8f8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -1,41 +1,38 @@
 
1
  from PIL import Image
2
- from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
3
  import gradio as gr
4
 
5
- # Load the model and processor
6
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/git-base")
7
- feature_extractor = ViTFeatureExtractor.from_pretrained("microsoft/git-base")
8
- tokenizer = PreTrainedTokenizerFast.from_pretrained("microsoft/git-base")
9
 
10
  # Define the captioning function
11
- def caption_images(image):
12
- # Preprocess the image
13
- pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
14
  # Generate captions
15
- encoder_outputs = model.generate(pixel_values.to('cpu'), num_beams=5)
16
- generated_sentence = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
17
- return generated_sentence[0].strip()
18
 
19
  # Define Gradio interface components
20
  inputs = [
21
- gr.inputs.Image(type='pil', label='Original Image')
22
  ]
23
 
24
  outputs = [
25
- gr.outputs.Textbox(label='Caption')
26
  ]
27
 
28
  # Define Gradio app properties
29
- title = "Simple Image Captioning Application"
30
- description = "Upload an image to see the caption generated"
31
- example = ['messi.jpg'] # Replace with a valid path to an example image
32
 
33
  # Create and launch the Gradio interface
34
  gr.Interface(
35
- fn=caption_images,
36
  inputs=inputs,
37
  outputs=outputs,
38
  title=title,
39
  description=description,
40
- examples=example,
41
  ).launch(debug=True)
 
1
+ from transformers import AutoProcessor, AutoModelForCausalLM
2
  from PIL import Image
 
3
  import gradio as gr
4
 
5
+ # Load the processor and model
6
+ processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
7
+ model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
 
8
 
9
  # Define the captioning function
10
+ def caption_image(image):
11
+ # Process the image
12
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
13
  # Generate captions
14
+ generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
15
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
16
+ return generated_caption
17
 
18
  # Define Gradio interface components
19
  inputs = [
20
+ gr.inputs.Image(type='pil', label='Upload Image')
21
  ]
22
 
23
  outputs = [
24
+ gr.outputs.Textbox(label='Generated Caption')
25
  ]
26
 
27
  # Define Gradio app properties
28
+ title = "Image Captioning Application"
29
+ description = "Upload an image to see the caption generated by the model"
 
30
 
31
  # Create and launch the Gradio interface
32
  gr.Interface(
33
+ fn=caption_image,
34
  inputs=inputs,
35
  outputs=outputs,
36
  title=title,
37
  description=description,
 
38
  ).launch(debug=True)