File size: 5,860 Bytes
4be498f 8b3ed0b 4be498f 8b3ed0b 3ca0b1c 4be498f f0aa736 9a63147 f0aa736 141f85a f0aa736 07aea06 f0aa736 07aea06 f0aa736 4be498f 8b3ed0b 8e6266f 8b3ed0b d4edf4a ee13260 d4edf4a 8b3ed0b d4edf4a 3872559 8e6266f fc53eb3 d4edf4a fc53eb3 d4edf4a fc53eb3 ee13260 d6d53fa ee13260 8b3ed0b fc53eb3 d4edf4a fc53eb3 7498bae 8e6266f fc53eb3 8e6266f d4edf4a 8e6266f fc53eb3 0efa8d3 d4edf4a 8e6266f 4be498f 8b3ed0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
#------------------------------------------------------------------------
# Import
#------------------------------------------------------------------------
import streamlit as st
import requests
from PIL import Image
import io
import os
#------------------------------------------------------------------------
# HF API
#------------------------------------------------------------------------
# The Hugging Face token is read from the environment so it stays out of
# source control; the app cannot work without it, so fail fast at startup.
hf_api_key = os.environ.get("HF_API_KEY")
if hf_api_key is None or hf_api_key == "":
    raise ValueError("HF_API_KEY not set in environment variables")

# Inference API endpoint for the SDXL base 1.0 text-to-image model.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "stabilityai/stable-diffusion-xl-base-1.0"
)
# Bearer-token authorization header attached to inference requests.
headers = {"Authorization": "Bearer " + hf_api_key}
#------------------------------------------------------------------------
# Configurations
#------------------------------------------------------------------------
# Streamlit page setup
# Must run before any other Streamlit command; this is the first st.* call
# in the script.
st.set_page_config(
    page_title="Stxtement | Image Generation",
    page_icon=":art:",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        # NOTE(review): this address looks like a placeholder/redacted value;
        # confirm it is a real, valid mailto: URL before deploying.
        'Get Help': 'mailto:[email protected]',
        # Shown in the app's "About" menu entry; describes this page's purpose.
        'About': "This app generates images from text prompts using Stable Diffusion XL.",
    },
)
#------------------------------------------------------------------------
# Sidebar
#------------------------------------------------------------------------
# Sidebar: branding, support contact, and usage instructions.
with st.sidebar:
    st.title("MTSS.ai")

    # Contact details for support and bug reports, collapsed by default.
    with st.expander("Need help and report a bug"):
        # "\n\n" starts a new Markdown paragraph; a single newline is a soft
        # break that Markdown collapses, which rendered both entries on one line.
        st.write(
            "**Contact**: Cheyne LeVesseur, PhD\n\n"
            "**Email**: [email protected]"
        )

    st.divider()
    st.subheader('User Instructions')
    st.markdown(
        "Enter a detailed description of the image you want to generate, "
        "and the app will create it based on your prompt."
    )
#------------------------------------------------------------------------
# Define functions
#------------------------------------------------------------------------
# SIMPLE CODE
# def query(payload):
# response = requests.post(API_URL, headers=headers, json=payload)
# if response.status_code != 200:
# st.error(f"Error: {response.status_code} - {response.text}")
# return None
# return response.content
# def generate_image(prompt):
# image_bytes = query({"inputs": prompt})
# if image_bytes:
# return Image.open(io.BytesIO(image_bytes))
# return None
# def main():
# st.title("Stxtement | Image Generation")
# prompt = st.text_input("Enter a prompt for image generation:")
# if st.button("Generate Image"):
# if prompt:
# image = generate_image(prompt)
# if image:
# st.image(image, caption="Generated Image")
# else:
# st.warning("Please enter a prompt.")
# COMPREHENSIVE CODE
def query(payload, *, timeout=120):
    """POST *payload* to the Hugging Face Inference API and return the raw body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely, freezing the
            Streamlit script run (and its spinner).

    Returns:
        The response body as ``bytes`` on HTTP 200; otherwise ``None`` after
        the problem has been shown in the app with ``st.error``.
    """
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts would otherwise surface as an
        # uncaught traceback in the app.
        st.error(f"Request failed: {exc}")
        return None
    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None
    return response.content
def generate_image(prompt):
    """Generate an image for *prompt* through the inference API.

    Returns:
        A fully decoded PIL ``Image`` on success, or ``None`` if the request
        failed (already reported by ``query``) or the response body could
        not be decoded as an image (reported here via ``st.error``).
    """
    image_bytes = query({"inputs": prompt})
    if not image_bytes:
        return None
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open only reads the header; force the full decode now so a
        # corrupt or truncated body fails here instead of later inside
        # st.image or the PNG download.
        image.load()
    except OSError as exc:
        # PIL.UnidentifiedImageError subclasses OSError; it is raised when a
        # 200 response is not actually image data.
        st.error(f"Could not decode the generated image: {exc}")
        return None
    return image
def generate_image_callback():
    """Generate an image from the current prompt and store it in session state.

    Reads the prompt from ``st.session_state["prompt_input"]`` (the text
    input's key). A blank or whitespace-only prompt shows a warning and does
    nothing else. Otherwise the prompt is saved under ``"prompt"``, and the
    result of ``generate_image`` is stored under ``"image"``. On failure that
    result is ``None``, which clears any image left over from an earlier
    prompt so it is not displayed as this prompt's result. The spinner is
    handled by the caller.
    """
    prompt = (st.session_state.get("prompt_input") or "").strip()
    if not prompt:
        st.warning("Please enter a prompt.")
        return
    st.session_state["prompt"] = prompt
    st.session_state["image"] = generate_image(prompt)
def reset_callback():
    """Restore the initial app state: empty prompt and no generated image.

    Used as the Reset button's ``on_click`` handler, so it runs before the
    next script rerun re-creates the ``prompt_input`` text widget.
    """
    initial_state = {"prompt": "", "prompt_input": "", "image": None}
    # Written in insertion order: prompt, prompt_input, image.
    for state_key, initial_value in initial_state.items():
        st.session_state[state_key] = initial_value
def main():
    """Render the image-generation page.

    Layout, top to bottom:
      * title and prompt input;
      * a "Generate Image" button that runs generation under a spinner;
      * when an image is stored in session state, the image itself and a
        PNG download button;
      * a "Reset" button whenever there is an image or a typed prompt.
    """
    st.title("Stxtement | Image Generation")

    # Keyed widget: its value lives in st.session_state["prompt_input"],
    # which reset_callback() clears. Passing value= from that same state was
    # redundant, so only key= is used.
    st.text_input("Enter a prompt for image generation:", key="prompt_input")

    if st.button("Generate Image"):
        if (st.session_state.get("prompt_input") or "").strip():
            # st.spinner renders in place, directly below the button. A
            # single-element st.empty() wrapper is unnecessary and would let
            # the spinner and any st.error emitted during generation replace
            # each other in the same slot.
            with st.spinner('Generating image...'):
                generate_image_callback()
        else:
            st.warning("Please enter a prompt.")

    image = st.session_state.get("image")
    if image is not None:
        st.image(image, caption="Generated Image")
        # Encode to PNG in memory for the download button.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format='PNG')
        st.download_button(
            label="Download Image",
            data=png_buffer.getvalue(),
            file_name="generated_image.png",
            mime="image/png",
        )

    # Offer Reset whenever there is something to clear (an image or a prompt).
    if image is not None or st.session_state.get("prompt_input"):
        st.button("Reset", on_click=reset_callback)
#------------------------------------------------------------------------
# Main Guard
#------------------------------------------------------------------------
if __name__ == "__main__":
    main()