maahin committed
Commit cc851d6 · verified · Parent: 2379c94

Update app.py

Files changed (1)
  1. app.py +14 -24
app.py CHANGED
@@ -2,9 +2,9 @@ import os
 import streamlit as st
 from PIL import Image
 import torch
-from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
 
-# Get Hugging Face API key from Hugging Face Spaces secrets
+# Get Hugging Face API key from environment variables
 HF_TOKEN = os.getenv("HF_KEY")
 
 # Ensure API key is available
@@ -12,12 +12,12 @@ if not HF_TOKEN:
     st.error("❌ Hugging Face API key not found! Set it as 'HF_KEY' in Spaces secrets.")
     st.stop()
 
-# Load the PaliGemma model and processor
+# Load the model and processor
 @st.cache_resource
 def load_model():
-    model_name = "google/paligemma2-3b-mix-224"
-    processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)
-    model = AutoModelForVision2Seq.from_pretrained(model_name, token=HF_TOKEN)
+    model_id = "google/paligemma2-3b-mix-224"
+    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval()
+    processor = PaliGemmaProcessor.from_pretrained(model_id)
     return processor, model
 
 processor, model = load_model()
@@ -31,33 +31,23 @@ if uploaded_file:
     image = Image.open(uploaded_file).convert("RGB")
     st.image(image, caption="Uploaded Image", use_container_width=True)
 
-    # User selects the task
+    # User input for task selection
     task = st.selectbox(
         "Select a task:",
         ["Generate a caption", "Answer a question", "Detect objects", "Generate segmentation"]
     )
 
-    # User input for question/prompt
+    # User prompt
     prompt = st.text_area("Enter a prompt (e.g., 'Describe the image' or 'What objects are present?')")
 
     if st.button("Run"):
         if prompt:
-            inputs = processor(text=prompt, images=image, return_tensors="pt")
+            inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
+            input_len = inputs["input_ids"].shape[-1]  # Get input length
 
-            with torch.no_grad():
-                output = model.generate(**inputs)
-
-            raw_output = processor.batch_decode(output, skip_special_tokens=False)[0]
-
-            # Handle different outputs
-            if task == "Generate a caption":
-                answer = raw_output
-            elif task == "Answer a question":
-                answer = raw_output
-            elif task == "Detect objects":
-                answer = f"Object bounding boxes: {raw_output}"
-            elif task == "Generate segmentation":
-                answer = f"Segmentation codes: {raw_output}"
+            with torch.inference_mode():
+                generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+                generation = generation[0][input_len:]  # Remove input tokens from output
+                answer = processor.decode(generation, skip_special_tokens=True)
 
             st.success(f"✅ Result: {answer}")
-
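
For context, the rewritten generation path follows the standard PaliGemma 2 inference pattern from transformers: inputs are cast to the model's bfloat16 dtype, and the prompt tokens are sliced off the output before decoding so only the newly generated text is shown. Below is a minimal standalone sketch of the same pattern outside Streamlit; the image path and prompt are placeholder values, not from the commit.

import torch
from PIL import Image
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-mix-224"

# Load in bfloat16; device_map="auto" places the weights on the available device
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
prompt = "caption en"  # placeholder prompt

# Match the input dtype to the model and move the tensors to its device
inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    # Slice off the echoed prompt tokens so only the answer is decoded
    answer = processor.decode(generation[0][input_len:], skip_special_tokens=True)

print(answer)

Trimming by input_len is what lets the commit drop the old per-task branches: generate() returns the prompt tokens followed by the completion, so slicing at the prompt length leaves just the model's answer for whichever task was requested.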