File size: 3,337 Bytes
f15b4fe
 
9e9b8c5
fad0c74
 
5e93509
fad0c74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2d1403
 
 
 
f9968a3
fad0c74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bc25a6
 
 
 
887ecc1
7db4531
887ecc1
9bc25a6
887ecc1
f15b4fe
3f8e55d
1b8efc6
 
8e6623c
b1a44f6
 
5e93509
b1a44f6
 
1b8efc6
b1a44f6
3f8e55d
5e93509
12ab3b2
9381f96
12ab3b2
021fdec
887ecc1
ec7dc37
97e6b2c
 
ac3772a
55b2a7e
9bc25a6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import streamlit as st
from PIL import Image
import inference
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import io
import requests
import copy
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import torch

# Drop the flash_attn requirement so Florence-2 can be loaded on CPU.
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports transformers should require for ``filename``.

    Wraps ``transformers.dynamic_module_utils.get_imports``. For Florence-2's
    remote ``modeling_florence2.py`` it removes ``flash_attn`` from the list
    so loading does not fail when that package is not installed (e.g. CPU).

    Args:
        filename: Path of the remote-code module being inspected.

    Returns:
        The list of module names reported by ``get_imports``, minus
        ``flash_attn`` for the Florence-2 modeling file.
    """
    imports = get_imports(filename)
    # Guard the removal: a model revision that no longer lists flash_attn
    # would otherwise make list.remove raise ValueError and break loading.
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

# Per-session flag recording whether load_model() has already populated
# st.session_state; it starts as False the first time this script runs.
if "model_loaded" not in st.session_state:
    st.session_state["model_loaded"] = False

# Load the processor and an int8 dynamically quantized model into session state.
def load_model(model_id: str = "microsoft/Florence-2-large") -> None:
    """Load the processor and a dynamically quantized model for ``model_id``.

    Side effects:
        * ``st.session_state.processor`` is set to the model's processor.
        * A ``temp`` directory is created in the working directory if missing.
        * ``st.session_state.model`` is set to the quantized model.
        * ``st.session_state.model_loaded`` is set to True.

    Args:
        model_id: Hugging Face Hub id to load. Defaults to
            ``"microsoft/Florence-2-large"`` (the previously hard-coded id).
    """
    # The processor only preprocesses text/images and holds no weights, so no
    # dtype is passed (qint8 is a quantized-tensor dtype, not a load dtype).
    st.session_state.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    # Best-effort scratch directory: an existing directory is fine
    # (exist_ok=True); any other OS failure is reported instead of being
    # silently swallowed by a bare except.
    try:
        os.makedirs("temp", exist_ok=True)
    except OSError as err:
        st.warning(f"Could not create temp directory: {err}")

    # Load the model; the patched get_imports drops the flash_attn requirement
    # of the remote modeling code so it loads without that package.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", trust_remote_code=True)

    # Quantize the weights of every nn.Linear layer to int8 (dynamic quantization).
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    # quantize_dynamic returns a copy by default (inplace=False); drop the
    # full-precision model reference so it can be garbage-collected.
    del model
    st.session_state.model = quantized_model
    st.session_state.model_loaded = True
    st.write("model loaded complete")
# On the first run of a session, load the model under a spinner; later
# Streamlit reruns find model_loaded already True and skip this.
if not st.session_state["model_loaded"]:
    with st.spinner('Loading model...'):
        load_model()


# Per-session "has_run" flag (intended to block re-running); initialized to
# False once. NOTE(review): it is not read in the code visible here.
if "has_run" not in st.session_state:
    st.session_state["has_run"] = False

# Main UI container
st.markdown('<h3><center><b>VQA</b></center></h3>', unsafe_allow_html=True)
# Image upload area
uploaded_image = st.sidebar.file_uploader("Upload your image here", type=["jpg", "jpeg", "png"])
# Display the uploaded image and, on request, run inference on it.
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    # Normalize to 3-channel RGB (e.g. RGBA or palette-mode PNGs).
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Resize to a fixed 256x256; note this does not preserve aspect ratio.
    image = image.resize((256, 256))

    # Display the image using Streamlit
    st.image(image, caption="Uploaded Image", use_column_width=True)
    # Task prompt input
    task_prompt = st.sidebar.text_input("Task Prompt", value="<MORE_DETAILED_CAPTION>")
    # Additional text input (optional). Recent Streamlit versions reject a
    # text_area height below 68 px, so 68 is the smallest portable value.
    text_input = st.sidebar.text_area("Input Questions", value="<MORE_DETAILED_CAPTION>", height=68)
    # Generate Caption button
    if st.sidebar.button("Generate Caption", key="Generate"):
        # NOTE(review): inference.run_example lives outside this file; how it
        # combines task_prompt and text_input is not visible here.
        output = inference.run_example(image, st.session_state.model, st.session_state.processor, task_prompt, text_input)
        st.write(output)