Spaces:
Sleeping
Sleeping
import streamlit as st | |
import time | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from transformers import AutoTokenizer, AutoModel, AutoConfig | |
import torch | |
from tqdm import tqdm | |
import gan_cls_768 | |
from torch.autograd import Variable | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
# Target device for every model and tensor in this app: the GPU when torch
# can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases the text, trims surrounding whitespace, then trims any leading
    or trailing periods (in that order, so whitespace exposed by removing a
    period is kept).
    """
    return txt.lower().strip().strip(".")
# Fixed token length every caption is padded / truncated to.
max_len = 76


def tokenize(tokenizer, txt):
    """Run ``tokenizer`` on ``txt`` with fixed-length padding and truncation.

    Args:
        tokenizer: A callable tokenizer (e.g. a Hugging Face tokenizer).
        txt: The caption to tokenise.

    Returns:
        Whatever ``tokenizer`` returns for a single, ``max_len``-long sequence.
    """
    options = {
        "max_length": max_len,
        "padding": "max_length",
        "truncation": True,
        "return_offsets_mapping": False,
    }
    return tokenizer(txt, **options)
def encode(model, tokenizer, txt):
    """Embed a caption and return the encoder output at the first token position.

    Args:
        model: Transformer encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        txt: Raw caption; normalised with ``clean`` before tokenisation.

    Returns:
        A 1-D numpy array: ``last_hidden_state`` at batch 0, position 0.
    """
    tokens = tokenize(tokenizer, clean(txt))
    # Each field holds a single sequence; add a leading batch axis of size 1
    # and place it on the model's device. Build a new dict instead of
    # mutating the tokenizer output while iterating it.
    inputs = {
        name: torch.tensor(values, dtype=torch.long, device=device)[None]
        for name, values in tokens.items()
    }
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    # Index batch 0 / position 0 explicitly. The previous `.squeeze()[0]`
    # would drop *any* size-1 axis (e.g. a length-1 sequence) and then
    # select the wrong thing.
    return outputs.last_hidden_state[0, 0].cpu().numpy()
@st.cache_resource
def get_model_roberta():
    """Load the ``roberta-base`` encoder and tokenizer.

    Decorated with ``st.cache_resource``, so Streamlit loads them once and
    reuses them across reruns and sessions. Previously every "Generate"
    click reloaded the full model.

    Returns:
        ``(model, tokenizer)`` with the model moved to ``device``.
    """
    model_name = 'roberta-base'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # output_hidden_states=True makes the model also return per-layer hidden
    # states; encode() in this file only reads last_hidden_state.
    model = AutoModel.from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name, output_hidden_states=True),
    ).to(device)
    return model, tokenizer
@st.cache_resource
def get_model_gan():
    """Build the text-conditioned generator and load its trained weights.

    Decorated with ``st.cache_resource``, so the checkpoint is read once and
    shared across reruns instead of being reloaded on every request.

    Returns:
        The generator wrapped in ``torch.nn.DataParallel``, in eval mode.
    """
    # Wrapped before loading. Presumably the checkpoint was saved from a
    # DataParallel model, so its keys carry a "module." prefix -- TODO confirm.
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    # load_state_dict copies the CPU tensors into the (possibly GPU) parameters.
    generator.load_state_dict(state_dict)
    generator.eval()
    return generator
def generate_image(text, n):
    """Sample ``n`` images from the generator, conditioned on ``text``.

    Args:
        text: Free-form description; embedded with ``encode``.
        n: Number of generator calls. Each call uses a fresh noise vector.

    Returns:
        A list of ``PIL.Image.Image``, one per image the generator returns.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()
    embed = encode(model, tokenizer, text)
    # 1-D embedding -> float32 tensor with a batch axis of 1, on `device`.
    right_embed = torch.as_tensor(embed, dtype=torch.float32).unsqueeze(0).to(device)

    images = []
    # Inference only: don't build autograd graphs (saves memory per call).
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Same 100 CPU draws as randn(1, 100), reshaped to (1, 100, 1, 1).
            noise = torch.randn(1, 100, 1, 1).to(device)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # x * 127.5 + 127.5 maps [-1, 1] to [0, 255]. This assumes the
                # generator output lies in [-1, 1]; clamp so out-of-range
                # values saturate instead of wrapping when cast to uint8.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # (C, H, W) -> (H, W, C), the layout PIL expects.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
# Page metadata. Streamlit requires set_page_config to be the first st.* call.
_page_settings = {
    "page_title": "ImageGen",
    "page_icon": "🧊",
    "layout": "centered",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)

# CSS injected into the page to hide Streamlit's main menu, footer and header.
hide_st_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden;}\n"
    "footer {visibility: hidden;}\n"
    "header {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_st_style, unsafe_allow_html=True)
# Example captions offered in the "Select from example" dropdown. Each entry
# must be unique: a duplicate shows up twice in the selectbox.
examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistil and pollen tube.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "delicate pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this flower has petals that are pink and bell shaped",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "this flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "a velvet large flower with a dark marking and a green stem.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are pink and has red stamen",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "this flower is yellow in color, with petals that are skinny and pointed.",
    "the petals on this flower are orange with a purple pistil.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
    "this flower has petals that are red and are very thin",
    "a flower with many folded over bright yellow petals",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
    "this flower is yellow and brown in color, with petals that are oval shaped.",
    "this flower has petals that are white and has a yellow stigma",
    "a flower with folded open and back red petals with black spots and think red anther",
    "this flower has large light red petals and a few white stamen in the center",
    "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
    "this flower is a beauty with light red leaves in an equal circle.",
    "a flower with an open conical red petal and white anther supported by red filaments",
    "this flower is red in color, with petals that are bell shaped.",
    "the petals of this flower are yellow with a long stigma",
]
def app():
    """Render the text-to-flower demo page.

    The user picks an example caption (or edits it in the text area). On
    "Generate", the app samples a grid of images with ``generate_image`` and
    shows them as a matplotlib figure in the right-hand column.
    """
    st.title("Text to Flower")
    st.markdown(
        """
        **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
        Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
        """
    )
    se = st.selectbox("Select from example", examples)
    row1_col1, row1_col2 = st.columns([2, 3])
    # The grid shape also decides how many images are sampled.
    n_rows, n_cols = 3, 4
    with row1_col1:
        caption = st.text_area("Write your flower description here:", se, height=120)
        # Only one backend exists. The widget is shown, but its value is unused.
        st.selectbox("Select a Model", ["Convolutional GAN with RoBERTa"], index=0)
        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, n_rows * n_cols)
            with row1_col2:
                st.markdown("Generated Flower Images:")
                fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
                for idx, axis in enumerate(axes.flat):
                    # Leave a cell blank rather than raise if fewer images came back.
                    if idx < len(imgs):
                        axis.imshow(imgs[idx])
                    axis.axis('off')
                fig.tight_layout()
                st.pyplot(fig)
                # pyplot keeps every figure alive until closed. Without this,
                # each click leaks a figure.
                plt.close(fig)
app()

# Footer rendered below the app on every run.
for footer_line in (
    "---",
    "Back to [www.shamimahamed.com](https://www.shamimahamed.com/).",
):
    st.markdown(footer_line)