import io
import os
from unittest.mock import patch

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

import inference
# Remove flash_attn from the dynamic import list so the model can be loaded on CPU
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
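
# Note: Florence-2's remote modeling code declares flash_attn as a hard import even
# though the SDPA attention path used below does not need it; filtering the import
# list is what allows loading on a CPU-only machine without a flash_attn wheel.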
# Initialize session state for model loading and to block re-running
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
# Load and quantize the model (Florence-2) and its processor into session state
def load_model():
    model_id = "microsoft/Florence-2-large"
    # Load the processor (tokenizer + image processor)
    st.session_state.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    # Scratch directory for temporary files
    os.makedirs("temp", exist_ok=True)
    # Load the model with the patched import check
    # (workaround for the unnecessary flash_attn requirement) and SDPA attention
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", trust_remote_code=True)
    # Apply dynamic quantization to every Linear layer, then drop the fp32 model
    Qmodel = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    del model
    st.session_state.model = Qmodel
    st.session_state.model_loaded = True
    st.write("Model loading complete")
# Load the model only once
if not st.session_state.model_loaded:
    with st.spinner('Loading model...'):
        load_model()
# Initialize session state to block re-running
if 'has_run' not in st.session_state:
    st.session_state.has_run = False
# Main UI container
st.markdown('<h3><center><b>VQA</b></center></h3>', unsafe_allow_html=True)
# Image upload area
uploaded_image = st.sidebar.file_uploader("Upload your image here", type=["jpg", "jpeg", "png"])
# Display the uploaded image and process it if available
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    # Remember the original format now; convert()/resize() return images whose format is None
    image_format = image.format if image.format else 'PNG'  # default to 'PNG' if unknown
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize((256, 256))
    # Save the image to a BytesIO object with a specific format
    image_bytes = io.BytesIO()
    image.save(image_bytes, format=image_format)
    image_bytes.seek(0)
    # Display the image using Streamlit
    st.image(image, caption="Uploaded Image", use_column_width=True)
    image_binary = image_bytes.getvalue()  # raw bytes of the processed image (not consumed below)
    # Task prompt input
    task_prompt = st.sidebar.text_input("Task Prompt", value="<MORE_DETAILED_CAPTION>")
    # Additional text input (optional)
    text_input = st.sidebar.text_area("Input Questions", value="<MORE_DETAILED_CAPTION>", height=68)
    # Generate Caption button
    if st.sidebar.button("Generate Caption", key="Generate"):
        output = inference.run_example(image, st.session_state.model, st.session_state.processor, task_prompt, text_input)
        st.write(output)
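
# The local `inference` module is not shown in this file. A minimal sketch of what
# `run_example` is assumed to do, following the standard Florence-2 model-card
# pattern (prompt construction, generate, post_process_generation); the actual
# body of inference.py may differ:
#
#   def run_example(image, model, processor, task_prompt, text_input=None):
#       prompt = task_prompt if not text_input else task_prompt + text_input
#       inputs = processor(text=prompt, images=image, return_tensors="pt")
#       generated_ids = model.generate(
#           input_ids=inputs["input_ids"],
#           pixel_values=inputs["pixel_values"],
#           max_new_tokens=1024,
#           num_beams=3,
#       )
#       text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
#       return processor.post_process_generation(
#           text, task=task_prompt, image_size=(image.width, image.height)
#       )
#
# Run the app with: streamlit run <this_file>.py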