File size: 5,860 Bytes
4be498f
 
 
 
8b3ed0b
 
 
 
 
 
4be498f
 
 
 
 
 
 
 
8b3ed0b
 
3ca0b1c
4be498f
f0aa736
 
 
 
 
9a63147
 
f0aa736
 
 
141f85a
f0aa736
 
 
 
 
 
 
 
 
 
 
 
 
 
07aea06
 
 
 
 
 
f0aa736
 
 
 
 
07aea06
f0aa736
 
 
 
 
 
 
 
 
 
4be498f
 
 
8b3ed0b
8e6266f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b3ed0b
 
 
 
 
 
 
 
 
 
 
 
 
d4edf4a
 
 
 
ee13260
 
 
 
d4edf4a
 
 
 
 
 
 
 
8b3ed0b
d4edf4a
3872559
8e6266f
fc53eb3
d4edf4a
fc53eb3
d4edf4a
fc53eb3
 
 
ee13260
 
d6d53fa
 
 
 
ee13260
 
 
 
 
 
 
8b3ed0b
fc53eb3
d4edf4a
fc53eb3
 
7498bae
8e6266f
 
fc53eb3
8e6266f
 
 
 
d4edf4a
8e6266f
fc53eb3
0efa8d3
d4edf4a
 
 
 
 
8e6266f
4be498f
 
 
 
8b3ed0b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#------------------------------------------------------------------------
# Import
#------------------------------------------------------------------------

import streamlit as st
import requests
from PIL import Image
import io
import os

#------------------------------------------------------------------------
# HF API
#------------------------------------------------------------------------

# Hugging Face Inference API settings. The access token comes from the
# environment so it never has to be committed alongside the source.
hf_api_key = os.getenv('HF_API_KEY')
if hf_api_key in (None, ""):
    raise ValueError("HF_API_KEY not set in environment variables")

# Hosted Stable Diffusion XL base model endpoint.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "stabilityai/stable-diffusion-xl-base-1.0"
)
headers = {"Authorization": "Bearer " + hf_api_key}

#------------------------------------------------------------------------
# Configurations
#------------------------------------------------------------------------
# Streamlit page setup: browser-tab title/icon, layout, and the entries shown
# in the app's top-right menu.
st.set_page_config(
    page_title="Stxtement | Image Generation",
    page_icon=":art:",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        # NOTE(review): this address looks like an obfuscated placeholder left
        # over from copying the code from a web page — confirm the real email.
        'Get Help': 'mailto:[email protected]',
        # Previously read "built to support spreadsheet analysis", which
        # described a different app; this one generates images from prompts.
        'About': "This app generates images from text prompts using Stable Diffusion XL.",
    }
)

#------------------------------------------------------------------------
# Sidebar
#------------------------------------------------------------------------
# Width (px) reserved for an optional sidebar logo. The logo is not displayed
# at present, so nothing reads this value.
image_width = 300

# Markdown shown under the "User Instructions" heading.
User_Instructions = """
    Enter a detailed description of the image you want to generate, and the app will create it based on your prompt.
    """

st.sidebar.title("MTSS.ai")

# Collapsible support section with the maintainer's contact details.
with st.sidebar.expander("Need help and report a bug"):
    st.write("""
        **Contact**: Cheyne LeVesseur, PhD  
        **Email**: [email protected]
        """)

st.sidebar.divider()
st.sidebar.subheader('User Instructions')
st.sidebar.markdown(User_Instructions)

#------------------------------------------------------------------------
# Define functions
#------------------------------------------------------------------------

# SIMPLE CODE
# def query(payload):
#     response = requests.post(API_URL, headers=headers, json=payload)
#     if response.status_code != 200:
#         st.error(f"Error: {response.status_code} - {response.text}")
#         return None
#     return response.content

# def generate_image(prompt):
#     image_bytes = query({"inputs": prompt})
#     if image_bytes:
#         return Image.open(io.BytesIO(image_bytes))
#     return None

# def main():
#     st.title("Stxtement | Image Generation")

#     prompt = st.text_input("Enter a prompt for image generation:")
    
#     if st.button("Generate Image"):
#         if prompt:
#             image = generate_image(prompt)
#             if image:
#                 st.image(image, caption="Generated Image")
#         else:
#             st.warning("Please enter a prompt.")

# COMPREHENSIVE CODE
def query(payload, timeout=120):
    """POST *payload* to the Hugging Face inference endpoint and return the body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block the Streamlit script indefinitely.

    Returns:
        The raw response body (``bytes``) on HTTP 200. Otherwise ``None``,
        after an error message has been rendered in the app.
    """
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as err:
        # Connection errors and timeouts would otherwise surface as an
        # unhandled traceback in the app.
        st.error(f"Request failed: {err}")
        return None
    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None
    return response.content

def generate_image(prompt):
    """Generate an image for *prompt* via ``query`` and decode it with Pillow.

    Returns:
        A decoded ``PIL.Image.Image``. Returns ``None`` if the request failed
        (``query`` already reported it) or if the response body could not be
        decoded as an image (reported here).
    """
    image_bytes = query({"inputs": prompt})
    if not image_bytes:
        return None
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so a corrupt or non-image body fails
        # here instead of later inside st.image / save.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Could not decode the returned image: {err}")
        return None
    return image

def generate_image_callback():
    """Generate an image for the current prompt and store it in Session State.

    Reads the prompt from ``st.session_state["prompt_input"]``. If it is
    non-empty, the prompt is copied to ``st.session_state["prompt"]`` and
    ``st.session_state["image"]`` is set to the result of ``generate_image``.
    On failure that result is ``None``, so an image from an earlier prompt is
    not left on screen as if it belonged to this one.
    """
    prompt = st.session_state.get("prompt_input", "")
    if not prompt:
        st.warning("Please enter a prompt.")
        return
    st.session_state["prompt"] = prompt
    # The spinner is handled by the caller (main).
    st.session_state["image"] = generate_image(prompt)

def reset_callback():
    """Return the prompt, the prompt text box and the image to their empty state."""
    blank_values = {"prompt": "", "prompt_input": "", "image": None}
    for state_key, blank in blank_values.items():
        st.session_state[state_key] = blank

def main():
    """Render the prompt form, run generation on demand and show the result.

    Session State keys read here:
        ``prompt_input``: value of the prompt text box (bound via its widget key).
        ``image``: the last generated image, or absent/``None``.
    """
    st.title("Stxtement | Image Generation")

    # The widget already reads and writes st.session_state["prompt_input"]
    # through its key, so also passing that value back via `value=` is
    # redundant and gives the widget two sources of truth.
    st.text_input("Enter a prompt for image generation:", key="prompt_input")

    generate_button_clicked = st.button("Generate Image")

    # Placeholder directly below the button so the spinner shows up there.
    spinner_placeholder = st.empty()

    if generate_button_clicked:
        if st.session_state.get("prompt_input", ""):
            with spinner_placeholder:
                with st.spinner('Generating image...'):
                    generate_image_callback()
        else:
            st.warning("Please enter a prompt.")

    # Read after any generation above so a freshly generated image is shown.
    image = st.session_state.get("image")
    if image is not None:
        st.image(image, caption="Generated Image")

        # Encode the in-memory image as PNG for the download button.
        image_bytes = io.BytesIO()
        image.save(image_bytes, format='PNG')
        st.download_button(
            label="Download Image",
            data=image_bytes.getvalue(),
            file_name="generated_image.png",
            mime="image/png"
        )

    # Offer Reset whenever there is something to clear: an image or a prompt.
    if image is not None or st.session_state.get("prompt_input"):
        st.button("Reset", on_click=reset_callback)

#------------------------------------------------------------------------
# Main Guard
#------------------------------------------------------------------------

# Render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()