Spaces:
Running
Running
# app.py | |
import streamlit as st | |
from cerebras.cloud.sdk import Cerebras | |
import openai | |
import os | |
from dotenv import load_dotenv | |
from together import Together | |
# --- Assuming config.py and utils.py exist --- | |
import config | |
import utils | |
# --- Import BASE_PROMPT from prompt.py ---
# If the import fails, show the error in the page and halt this script run
# instead of crashing with a traceback.
try:
    from prompt import BASE_PROMPT
except ImportError as exc:
    # Append the underlying ImportError text: the same exception type is
    # raised when prompt.py exists but one of *its own* imports fails, so the
    # generic message alone could point at the wrong cause.
    st.error(
        "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it."
        f" ({exc})")
    st.stop()
# --- Import column rendering functions --- | |
from chat_column import render_chat_column | |
from image_column import render_image_column | |
# --- Environment and page setup ---
# Load variables from a .env file (if present) so the os.getenv() lookups
# further down can see them.
load_dotenv()

st.set_page_config(
    page_title="Prompt & Image Generator",
    page_icon="🤖",
    layout="wide",
)

# --- Page header ---
utils.display_icon("🤖")
st.title("Prompt & Image Generator")
st.subheader(
    "Generate text prompts (left) and edit/generate images (right)",
    anchor=False,
    divider="orange",
)
# --- API key handling ---
# A non-empty CEREBRAS_API_KEY in the environment is used directly; when it is
# missing or empty, the sidebar shows a password field for it instead.
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not api_key_from_env
cerebras_api_key = None  # filled in by the sidebar logic
# The Together key (image generation) is only ever read from the environment.
together_api_key = os.getenv("TOGETHER_API_KEY")
# --- Sidebar settings ---
with st.sidebar:
    st.title("Settings")

    # Cerebras key: take it from the environment, or ask the user for it.
    if not show_api_key_input:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")
    else:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input(
            "Cerebras API Key:",
            type="password",
            key="cerebras_api_key_input_field",
        )
        # Only accept a non-empty entry; otherwise the key stays None.
        if api_key_input:
            cerebras_api_key = api_key_input

    # Together key status: image generation needs it.
    if together_api_key:
        st.success("✓ Together API Key loaded from environment")
    else:
        st.warning(
            "TOGETHER_API_KEY environment variable not set. Image generation (right column) will not work.",
            icon="⚠️",
        )

    # LLM picker: options are the keys of config.MODELS, each shown using
    # that entry's "name" field.
    model_option = st.selectbox(
        "Choose a LLM model:",
        options=list(config.MODELS.keys()),
        format_func=lambda model_id: config.MODELS[model_id]["name"],
        key="model_select",
    )

    # Max-tokens slider: upper bound is the selected model's "tokens" entry;
    # the default is 2048, or that bound if it is smaller.
    model_token_limit = config.MODELS[model_option]["tokens"]
    max_tokens = st.slider(
        "Max Tokens (LLM):",
        min_value=512,
        max_value=model_token_limit,
        value=min(2048, model_token_limit),
        step=512,
        help="Max tokens for the LLM's text prompt response.",
    )

    use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False)
# --- Main application logic ---
# Fallback: if the sidebar did not yield a key but the input field is in use,
# read whatever was typed directly from the widget's session-state entry.
if not cerebras_api_key and show_api_key_input:
    typed_key = st.session_state.get("cerebras_api_key_input_field")
    if typed_key:
        cerebras_api_key = typed_key

# Without a Cerebras key nothing else can work, so halt this run here.
if not cerebras_api_key:
    st.error("Cerebras API Key is required. Please enter it in the sidebar or set the CEREBRAS_API_KEY environment variable.", icon="🚨")
    st.stop()
# --- API client initialization ---
# llm_client is either a direct Cerebras client or, when the Optillm toggle is
# on, an OpenAI-compatible client pointed at config.BASE_URL (authenticated
# with the Cerebras key). image_client stays None when no Together key is set.
llm_client = None
image_client = None

# Validate configuration before the try block: keep st.stop() out of the
# broad exception handler, whose only job is reporting client-constructor
# failures.
optillm_base_url = getattr(config, "BASE_URL", None)
if use_optillm and not optillm_base_url:
    st.error("Optillm selected, but BASE_URL is not configured.", icon="🚨")
    st.stop()

try:
    if use_optillm:
        llm_client = openai.OpenAI(
            base_url=optillm_base_url, api_key=cerebras_api_key)
    else:
        llm_client = Cerebras(api_key=cerebras_api_key)
    if together_api_key:
        image_client = Together(api_key=together_api_key)
except Exception as e:
    # Top-level UI boundary: surface any SDK construction error in the page
    # and halt this run rather than rendering with a missing client.
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()
# --- Session state initialization ---
# Streamlit re-executes this script on every interaction; only seed keys that
# are not present yet so chat and image history survive reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "current_image_prompt_text" not in st.session_state:
    st.session_state.current_image_prompt_text = ""
# Generated images are kept as a list so several images (with their prompts)
# can be stored.
if "generated_images_list" not in st.session_state:
    st.session_state.generated_images_list = []
# Drop the legacy single-image key from the older state layout, if present.
if "latest_generated_image" in st.session_state:
    del st.session_state["latest_generated_image"]
if "selected_model" not in st.session_state:
    st.session_state.selected_model = None

# --- Clear history when the LLM model changes ---
if st.session_state.selected_model is None:
    # First run of this session: the state above was just seeded empty, so
    # there is nothing to clear. Record the model without forcing a redundant
    # st.rerun() (which would re-execute the whole script, clients included).
    st.session_state.selected_model = model_option
elif st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.current_image_prompt_text = ""
    st.session_state.generated_images_list = []
    st.session_state.selected_model = model_option
    st.rerun()
# --- Main layout: chat column (2/3 width) beside image column (1/3 width) ---
chat_col, image_col = st.columns(spec=[2, 1])

# Each column's contents are drawn by its dedicated rendering module.
with chat_col:
    render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)

with image_col:
    # image_client is None when TOGETHER_API_KEY is not set.
    render_image_column(st, image_client)