# app.py
"""Streamlit entry point for the Prompt & Image Generator.

The page is split into two columns: the left column uses an LLM client
(Cerebras, or an OpenAI-compatible Optillm endpoint) to generate text
prompts, and the right column edits/generates images via a Together client.
"""
import os

import openai
import streamlit as st
from cerebras.cloud.sdk import Cerebras
from dotenv import load_dotenv
from together import Together

# --- Load environment variables ---
# Done before importing project modules, in case any of them (e.g. config)
# read environment variables at import time.
load_dotenv()

# --- Assuming config.py and utils.py exist ---
import config
import utils

# --- Import BASE_PROMPT ---
try:
    from prompt import BASE_PROMPT
except ImportError:
    st.error(
        "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it.")
    st.stop()

# --- Import column rendering functions ---
from chat_column import render_chat_column
from image_column import render_image_column

# --- Streamlit page configuration ---
st.set_page_config(page_icon="🤖", layout="wide",
                   page_title="Prompt & Image Generator")

# --- Header ---
utils.display_icon("🤖")
st.title("Prompt & Image Generator")
st.subheader("Generate text prompts (left) and edit/generate images (right)",
             divider="orange", anchor=False)

# --- API keys ---
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
# Only ask for the Cerebras key in the UI when the environment lacks one.
show_api_key_input = not api_key_from_env
cerebras_api_key = None
together_api_key = os.getenv("TOGETHER_API_KEY")

# --- Sidebar ---
with st.sidebar:
    st.title("Settings")

    # Cerebras key: taken from the environment if present, otherwise typed in.
    # The text_input return value already equals
    # st.session_state["cerebras_api_key_input_field"], so no separate
    # session-state re-check is needed later.
    if show_api_key_input:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input(
            "Cerebras API Key:", type="password",
            key="cerebras_api_key_input_field")
        if api_key_input:
            cerebras_api_key = api_key_input
    else:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")

    # Together key status (read from the environment only).
    if not together_api_key:
        st.warning(
            "TOGETHER_API_KEY environment variable not set. Image generation (right column) will not work.",
            icon="⚠️")
    else:
        st.success("✓ Together API Key loaded from environment")

    # Model selection
    model_option = st.selectbox(
        "Choose a LLM model:",
        options=list(config.MODELS.keys()),
        format_func=lambda x: config.MODELS[x]["name"],
        key="model_select",
    )

    # Max tokens slider, capped by the selected model's "tokens" limit.
    # NOTE(review): st.slider needs min_value < max_value; this assumes every
    # model in config.MODELS allows more than 512 tokens — TODO confirm.
    max_tokens_range = config.MODELS[model_option]["tokens"]
    default_tokens = min(2048, max_tokens_range)
    max_tokens = st.slider(
        "Max Tokens (LLM):",
        min_value=512,
        max_value=max_tokens_range,
        value=default_tokens,
        step=512,
        help="Max tokens for the LLM's text prompt response.",
    )

    use_optillm = st.toggle(
        "Use Optillm (for Cerebras)", value=False)

# --- Main application logic ---
if not cerebras_api_key:
    st.error("Cerebras API Key is required. Please enter it in the sidebar or set the CEREBRAS_API_KEY environment variable.", icon="🚨")
    st.stop()

# --- API client initialization ---
# Validate configuration (and st.stop() on failure) before the try block, so
# the broad `except Exception` below only covers client construction.
base_url = getattr(config, "BASE_URL", None)
if use_optillm and not base_url:
    st.error("Optillm selected, but BASE_URL is not configured.", icon="🚨")
    st.stop()

llm_client = None
image_client = None
try:
    if use_optillm:
        # Optillm is reached through an OpenAI-compatible endpoint at BASE_URL.
        llm_client = openai.OpenAI(base_url=base_url, api_key=cerebras_api_key)
    else:
        llm_client = Cerebras(api_key=cerebras_api_key)
    # Without a Together key, image_client stays None.
    if together_api_key:
        image_client = Together(api_key=together_api_key)
except Exception as e:
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()

# --- Session state initialization ---
if "messages" not in st.session_state:
    st.session_state.messages = []
if "current_image_prompt_text" not in st.session_state:
    st.session_state.current_image_prompt_text = ""
# Generated images are kept as a list so multiple images and their prompts
# can be stored; this replaces the older single-image state.
if "generated_images_list" not in st.session_state:
    st.session_state.generated_images_list = []
# Drop the legacy single-image key if an older session still carries it.
if "latest_generated_image" in st.session_state:
    del st.session_state["latest_generated_image"]
if "selected_model" not in st.session_state:
    st.session_state.selected_model = None

# --- Reset conversation and images when the model changes ---
# selected_model starts as None, so this branch also runs (and reruns the
# script) once on the very first load.
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.current_image_prompt_text = ""
    st.session_state.generated_images_list = []
    st.session_state.selected_model = model_option
    st.rerun()

# --- Layout: chat column (left, 2/3 width) and image column (right, 1/3) ---
chat_col, image_col = st.columns([2, 1])

with chat_col:
    render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)

with image_col:
    render_image_column(st, image_client)