# app.py
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import openai
import os
from dotenv import load_dotenv
from together import Together
# --- Assuming config.py and utils.py exist ---
import config
import utils
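# NOTE: the shapes below are assumptions inferred from how config/utils are used later in
# this file, not definitions from those modules:
#   config.MODELS   -> dict mapping model id to {"name": str, "tokens": int}, e.g.
#                      MODELS = {"llama3.1-8b": {"name": "Llama 3.1 8B", "tokens": 8192}}
#   config.BASE_URL -> OpenAI-compatible endpoint, only needed when Optillm is enabled
#   utils.display_icon(emoji) -> renders the page icon in the header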
# --- Import BASE_PROMPT ---
try:
    from prompt import BASE_PROMPT
except ImportError:
    st.error(
        "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it.")
    st.stop()
# --- Import column rendering functions ---
from chat_column import render_chat_column
from image_column import render_image_column
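# The two renderers are assumed (from the calls at the bottom of this file) to accept
# (st, llm_client, model_option, max_tokens, base_prompt) and (st, image_client) respectively.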
# --- Load environment variables ---
load_dotenv()
# --- Streamlit page configuration ---
st.set_page_config(page_icon="🤖", layout="wide",
                   page_title="Prompt & Image Generator")
# --- Render page header ---
utils.display_icon("🤖")
st.title("Prompt & Image Generator")
st.subheader("Generate text prompts (left) and edit/generate images (right)",
             divider="orange", anchor=False)
# --- API key handling ---
# (API Key logic remains the same)
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not bool(api_key_from_env)
cerebras_api_key = None
together_api_key = os.getenv("TOGETHER_API_KEY")
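# Key resolution strategy used below: the Cerebras key is read from the environment when
# available and otherwise requested via a sidebar text input; the Together key comes from
# the environment only, and image generation is simply disabled when it is missing.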
# --- Sidebar settings ---
# (Sidebar logic remains the same)
with st.sidebar:
    st.title("Settings")
    # Cerebras Key Input
    if show_api_key_input:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input(
            "Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
        if api_key_input:
            cerebras_api_key = api_key_input
    else:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")
    # Together Key Status
    if not together_api_key:
        st.warning(
            "TOGETHER_API_KEY environment variable not set. Image generation (right column) will not work.", icon="⚠️")
    else:
        st.success("✓ Together API Key loaded from environment")
    # Model selection
    model_option = st.selectbox(
        "Choose an LLM model:",
        options=list(config.MODELS.keys()),
        format_func=lambda x: config.MODELS[x]["name"],
        key="model_select"
    )
    # Max tokens slider
    max_tokens_range = config.MODELS[model_option]["tokens"]
    default_tokens = min(2048, max_tokens_range)
    max_tokens = st.slider(
        "Max Tokens (LLM):",
        min_value=512,
        max_value=max_tokens_range,
        value=default_tokens,
        step=512,
        help="Max tokens for the LLM's text prompt response."
    )
    use_optillm = st.toggle(
        "Use Optillm (for Cerebras)", value=False)
# --- Main application logic ---
# Re-check Cerebras API key
if (not cerebras_api_key and show_api_key_input
        and 'cerebras_api_key_input_field' in st.session_state
        and st.session_state.cerebras_api_key_input_field):
    cerebras_api_key = st.session_state.cerebras_api_key_input_field
if not cerebras_api_key:
    st.error("Cerebras API Key is required. Please enter it in the sidebar or set the CEREBRAS_API_KEY environment variable.", icon="🚨")
    st.stop()
# --- API client initialization ---
# (Client initialization remains the same)
llm_client = None
image_client = None
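# When the Optillm toggle is on, the LLM client is an OpenAI-compatible client pointed at
# config.BASE_URL (assumed here to be an Optillm proxy in front of Cerebras); otherwise the
# Cerebras SDK is used directly. image_client stays None if no Together key was provided.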
try:
    if use_optillm:
        if not hasattr(config, 'BASE_URL') or not config.BASE_URL:
            st.error("Optillm selected, but BASE_URL is not configured.", icon="🚨")
            st.stop()
        llm_client = openai.OpenAI(
            base_url=config.BASE_URL, api_key=cerebras_api_key)
    else:
        llm_client = Cerebras(api_key=cerebras_api_key)
    if together_api_key:
        image_client = Together(api_key=together_api_key)
except Exception as e:
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()
# --- Session State Initialization ---
# Initialize state variables if they don't exist
if "messages" not in st.session_state:
st.session_state.messages = []
if "current_image_prompt_text" not in st.session_state:
st.session_state.current_image_prompt_text = ""
# --- MODIFICATION START ---
# Replace single image state with a list to store multiple images and their prompts
if "generated_images_list" not in st.session_state:
st.session_state.generated_images_list = [] # Initialize as empty list
# Remove old state variable if it exists (optional cleanup)
if "latest_generated_image" in st.session_state:
del st.session_state["latest_generated_image"]
# --- MODIFICATION END ---
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
# --- Clear history if model changes ---
if st.session_state.selected_model != model_option:
st.session_state.messages = []
st.session_state.current_image_prompt_text = ""
# --- MODIFICATION START ---
# Clear the list of generated images when model changes
st.session_state.generated_images_list = []
# --- MODIFICATION END ---
st.session_state.selected_model = model_option
st.rerun()
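# st.rerun() restarts the script immediately so both columns render against the freshly
# cleared session state for the newly selected model.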
# --- Define Main Columns ---
chat_col, image_col = st.columns([2, 1])
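# The chat column takes roughly two thirds of the page width, the image column one third.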
# --- Render Columns using imported functions ---
with chat_col:
    render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)
with image_col:
    render_image_column(st, image_client)  # Pass the client
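# Local usage sketch (assuming prompt.py, config.py, utils.py and the column modules exist):
#   put CEREBRAS_API_KEY=... and TOGETHER_API_KEY=... in a .env file, then run:
#   streamlit run app.py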