Spaces:
Runtime error
Runtime error
File size: 4,674 Bytes
324f080 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from io import BytesIO
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from streamlit.elements import markdown
from PIL import Image
from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
FlaxCLIPVisionMBartForConditionalGeneration,
)
from transformers import MBart50TokenizerFast
from utils import (
get_transformed_image,
)
import matplotlib.pyplot as plt
from mtranslate import translate
from session import _get_state
# Session state object provided by the local `session` module. The script
# stores the currently selected image file, caption, language id and image
# array on it (presumably so they survive Streamlit reruns -- confirm in
# session.py).
state = _get_state()
@st.cache
def load_model(ckpt):
    """Load the CLIP-Vision + mBART captioning model from a checkpoint.

    Args:
        ckpt: Checkpoint location handed straight to ``from_pretrained``.

    Returns:
        The ``FlaxCLIPVisionMBartForConditionalGeneration`` instance built
        by ``from_pretrained``. The call is wrapped in ``st.cache``.
    """
    model_cls = FlaxCLIPVisionMBartForConditionalGeneration
    return model_cls.from_pretrained(ckpt)
# Tokenizer for the "facebook/mbart-large-50" checkpoint, created at module
# level (so on every execution of this script). `generate_sequence` uses it to
# look up the forced BOS token id for a language and to decode generated ids.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
# Two-letter language id -> language code used to index
# `tokenizer.lang_code_to_id` in `generate_sequence`.
language_mapping = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX",
    "es": "es_XX",
}

# Two-letter language id -> display name shown in the language selectbox.
# Insertion order matters: the selectbox options are built from these keys.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}
@st.cache(persist=True)
def generate_sequence(pixel_values, lang_code, num_beams):
    """Generate a caption for an image in the requested language.

    Args:
        pixel_values: Model input for the image; passed to ``model.generate``
            as ``input_ids`` (the caller passes the output of
            ``get_transformed_image``).
        lang_code: Two-letter language id; must be a key of
            ``language_mapping``.
        num_beams: Beam count forwarded to ``model.generate``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the first element of the
        generation output; the caller reads element 0 as the caption.

    Raises:
        KeyError: If ``lang_code`` is not in ``language_mapping``.
    """
    # Relies on the module-level `model`, which the script assigns before
    # this function is first called.
    mbart_code = language_mapping[lang_code]
    generated = model.generate(
        input_ids=pixel_values,
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_code],
        max_length=64,
        num_beams=num_beams,
    )
    print(generated)
    return tokenizer.batch_decode(
        generated[0],
        skip_special_tokens=True,
        max_length=64,
    )
def read_markdown(path, parent="./sections/"):
    """Return the full text of a Markdown file.

    Args:
        path: File name (or relative path) of the Markdown file.
        parent: Directory that ``path`` is resolved against.

    Returns:
        The file contents as a single string.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode as UTF-8 explicitly: without `encoding`, open() falls back to the
    # platform locale, which garbles or rejects non-ASCII text on hosts whose
    # default encoding is not UTF-8.
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
# Checkpoints the app knows about; only the first entry is loaded below.
checkpoints = ["./ckpt/ckpt-22499"]  # TODO: Maybe add more checkpoints?

# Reference examples. The script reads the "image_file", "caption" and
# "lang_id" columns of this tab-separated table.
dummy_data = pd.read_csv("reference.tsv", sep="\t")

# Page chrome: the same title is used for the browser tab and the heading.
page_title = "Multilingual Image Captioning"
st.set_page_config(
    page_title=page_title,
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.title(page_title)

authors = "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
st.write(authors)

# Sidebar: beam-search width, later passed to `generate_sequence`.
st.sidebar.title("Settings")
num_beams = st.sidebar.number_input(
    label="Number of Beams",
    min_value=2,
    max_value=10,
    value=4,
    step=1,
    help="Number of beams to be used in beam search.",
)

# Collapsible usage notes rendered from sections/usage.md.
with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))
def _select_example(table, row):
    """Copy row ``row`` of ``table`` into the session state and load its image.

    Reads the "image_file", "caption" and "lang_id" columns, strips leading
    and trailing "-" and space characters from the caption, and loads the
    image from the "images" directory.
    """
    state.image_file = table.loc[row, "image_file"]
    state.caption = table.loc[row, "caption"].strip("- ")
    state.lang_id = table.loc[row, "lang_id"]
    state.image = plt.imread(os.path.join("images", state.image_file))


# Row of the reference table shown when no image has been selected yet.
first_index = 20

# Init Session State
if state.image_file is None:
    _select_example(dummy_data, first_index)

# Left column shows the image; right column holds the controls.
col1, col2 = st.beta_columns([6, 4])

if col2.button("Get a random example"):
    # reset_index() makes the single sampled row addressable as row 0.
    _select_example(dummy_data.sample(1).reset_index(), 0)

col2.write("OR")

uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # NOTE(review): only the image and its file name are replaced here;
    # state.caption and state.lang_id keep the previous example's values.
    state.image_file = os.path.join("images", uploaded_file.name)
    state.image = np.array(Image.open(uploaded_file))

# Model input for whichever image is currently selected.
transformed_image = get_transformed_image(state.image)
# Show the currently selected image in the left column.
col1.image(state.image, use_column_width="auto")

# Show the stored caption, followed by its English rendering.
col2.write("**Reference Caption**: " + state.caption)
if state.lang_id == "en":
    # Already English: no translation call needed.
    english_reference = state.caption
else:
    english_reference = translate(state.caption, 'en')
col2.markdown(f"""**English Translation**: {english_reference}""")

# Language picker, pre-selected to the language of the stored caption.
options = list(code_to_name.keys())
lang_id = col2.selectbox(
    "Language",
    options=options,
    index=options.index(state.lang_id),
    format_func=lambda code: code_to_name[code],
)
# Load the captioning model, then generate and display a caption on demand.
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

# [''] is the sentinel for "no caption was generated during this run".
sequence = ['']
if col2.button("Generate Caption"):
    with st.spinner("Generating Sequence..."):
        sequence = generate_sequence(transformed_image, lang_id, num_beams)

if sequence != ['']:
    generated_caption = sequence[0]
    st.write("**Generated Caption**: " + generated_caption)

    # The conditional expression binds more loosely than `+`, so the language
    # check must be resolved before the label is prepended; otherwise the
    # "English Translation" label is dropped for every non-English caption.
    # The target language is passed explicitly, matching the
    # reference-caption translation earlier in the script.
    if lang_id == "en":
        english_caption = generated_caption
    else:
        english_caption = translate(generated_caption, "en")
    st.write("**English Translation**: " + english_caption)
# Project write-up rendered below the demo, in display order. Each entry pairs
# the Streamlit call used to render a section with its file under ./sections/.
_report_sections = (
    (st.write, "abstract.md"),
    (st.write, "caveats.md"),
    # Disabled: "# Methodology" heading and the figure
    # "./misc/Multilingual-IC.png" (caption: "Seq2Seq model for Image-text
    # Captioning.").
    (st.markdown, "pretraining.md"),
    (st.write, "challenges.md"),
    (st.write, "social_impact.md"),
    (st.write, "references.md"),
    # Disabled: "checkpoints.md".
    (st.write, "acknowledgements.md"),
)
for _render, _section_file in _report_sections:
    _render(read_markdown(_section_file))
|