miculpionier commited on
Commit
1c4ba29
·
1 Parent(s): c61a3a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -1,13 +1,20 @@
1
  import gradio
2
  from transformers import ViltProcessor, ViltForQuestionAnswering
3
  from PIL import Image
 
4
 
5
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
6
  model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
7
 
8
 
9
  def predict_answer(image, question):
10
- image = Image.fromarray(image.astype('uint8'), 'RGB')
 
 
 
 
 
 
11
  encoding = processor(image, question, return_tensors="pt")
12
  outputs = model(**encoding)
13
  logits = outputs.logits
@@ -24,7 +31,8 @@ def predict_answer(image, question):
24
 
25
 
26
  inputs = [
27
- gradio.components.Image(label="Image"),
 
28
  gradio.components.Textbox(label="Question", placeholder="Enter your question here.")
29
  ]
30
 
 
1
  import gradio
2
  from transformers import ViltProcessor, ViltForQuestionAnswering
3
  from PIL import Image
4
+ import requests
5
 
6
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
7
  model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
8
 
9
 
10
  def predict_answer(image, question):
11
+ if isinstance(image, str):
12
+ # Download image from URL
13
+ response = requests.get(image)
14
+ image = Image.open(BytesIO(response.content))
15
+ else:
16
+ # Convert numpy array to PIL Image
17
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
18
  encoding = processor(image, question, return_tensors="pt")
19
  outputs = model(**encoding)
20
  logits = outputs.logits
 
31
 
32
 
33
  inputs = [
34
+ gradio.inputs.Image(type="file", label="Upload Image"),
35
+ gradio.inputs.Image(type="url", label="Image URL"),
36
  gradio.components.Textbox(label="Question", placeholder="Enter your question here.")
37
  ]
38