test-zerogpu-2 / app.py
nyasukun's picture
Update app.py
b37f1c5 verified
raw
history blame
6.43 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
import logging
import spaces
# Logger setup: INFO level with a timestamped, logger-named line format.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Model registry: two local text-generation models plus one classification
# model (three models in total). Each entry holds a display name, a UI
# description, and the model id passed to AutoTokenizer/pipeline.
TEXT_GENERATION_MODELS = [
    dict(
        name="Llama-2",
        description="Known for its robust performance in content analysis",
        model_path="meta-llama/Llama-2-7b-hf",
    ),
    dict(
        name="Mistral-7B",
        description="Offers precise and detailed text evaluation",
        model_path="mistralai/Mistral-7B-v0.1",
    ),
]
CLASSIFICATION_MODELS = [
    dict(
        name="Toxic-BERT",
        description="Fine-tuned for toxic content detection",
        model_path="unitary/toxic-bert",
    ),
]

# Module-level registries, populated by preload_models() and keyed by model_path.
tokenizers = dict()
pipelines = dict()
def _load_pipeline(task, label, model_path):
    """Load the tokenizer and a ``task`` pipeline for ``model_path`` and register both.

    Args:
        task: transformers pipeline task name (e.g. "text-generation").
        label: Human-readable task name used only in log messages.
        model_path: Model id passed to ``AutoTokenizer.from_pretrained`` and
            ``pipeline``; also the key used in ``tokenizers`` / ``pipelines``.

    The tokenizer and pipeline are stored only after both have been built, so a
    failure never leaves a tokenizer registered without its pipeline. Errors are
    logged with their traceback and swallowed, so the remaining models still load.
    """
    try:
        logger.info("Preloading %s model: %s", label, model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NOTE(review): trust_remote_code=True runs code shipped in the model
        # repo. The configured models are presumably supported natively by
        # transformers; confirm and consider dropping this flag.
        pipe = pipeline(
            task,
            model=model_path,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        # Best-effort preload: record the full traceback and move on.
        logger.exception("Error preloading model %s: %s", model_path, e)
        return
    tokenizers[model_path] = tokenizer
    pipelines[model_path] = pipe
    logger.info("Model preloaded successfully: %s", model_path)


def preload_models():
    """Preload every configured model at application startup.

    Text-generation models are loaded first, then classification models.
    A model that fails to load is logged and skipped; it will be missing from
    ``pipelines``, and the inference functions report that to the user.
    """
    logger.info("Preloading models at application startup...")
    for model in TEXT_GENERATION_MODELS:
        _load_pipeline("text-generation", "text generation", model["model_path"])
    for model in CLASSIFICATION_MODELS:
        _load_pipeline("text-classification", "classification", model["model_path"])
@spaces.GPU
def generate_text(model_path, text):
    """Run greedy text generation on ``text`` with a preloaded pipeline.

    Args:
        model_path: Key into ``pipelines`` (the model id used at preload time).
        text: Prompt string.

    Returns:
        The pipeline's ``generated_text`` for the single returned sequence, or
        a string starting with "Error:" if the model is not loaded or
        generation fails.
    """
    pipe = pipelines.get(model_path)
    if pipe is None:
        # preload_models() skips models that failed to load; say so plainly
        # instead of surfacing a bare KeyError repr to the user.
        logger.error("Text generation model not loaded: %s", model_path)
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info("Running text generation with %s", model_path)
        outputs = pipe(
            text,
            max_new_tokens=100,
            do_sample=False,  # deterministic (greedy) decoding
            num_return_sequences=1,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {e}"
@spaces.GPU
def classify_text(model_path, text):
    """Classify ``text`` with a preloaded text-classification pipeline.

    Args:
        model_path: Key into ``pipelines`` (the model id used at preload time).
        text: Input string to classify.

    Returns:
        ``str()`` of the pipeline's raw result, or a string starting with
        "Error:" if the model is not loaded or classification fails.
    """
    pipe = pipelines.get(model_path)
    if pipe is None:
        # preload_models() skips models that failed to load; say so plainly
        # instead of surfacing a bare KeyError repr to the user.
        logger.error("Classification model not loaded: %s", model_path)
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info("Running classification with %s", model_path)
        result = pipe(text)
        return str(result)
    except Exception as e:
        logger.exception("Error in classification with %s: %s", model_path, e)
        return f"Error: {e}"
def handle_invoke(text):
    """Run ``text`` through every configured model and collect the outputs.

    Returns:
        A list with one string per model. Text-generation results come first,
        in TEXT_GENERATION_MODELS order, followed by classification results in
        CLASSIFICATION_MODELS order. This order must match the ``outputs`` list
        wired to the button in create_ui().

    Empty or whitespace-only input short-circuits with an error string for
    every model, so no GPU-backed call is made for it.
    """
    model_count = len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS)
    if text is None or not text.strip():
        return ["Error: please enter some text to analyze."] * model_count

    results = [generate_text(m["model_path"], text) for m in TEXT_GENERATION_MODELS]
    results += [classify_text(m["model_path"], text) for m in CLASSIFICATION_MODELS]
    return results
def _build_output_boxes(models):
    """Create one grouped, read-only output Textbox per model entry; return them in order."""
    boxes = []
    for model in models:
        with gr.Group():
            gr.Markdown(f"### {model['name']}")
            boxes.append(
                gr.Textbox(
                    label=f"{model['name']} Output",
                    lines=5,
                    interactive=False,
                    info=model["description"],
                )
            )
    return boxes


def create_ui():
    """Build the Gradio Blocks UI and return it (not yet launched).

    Layout:
    - a header;
    - an input textbox and an "Analyze Text" button;
    - one tab of output boxes per model category.

    The button calls handle_invoke(). Its outputs are the text-generation
    boxes followed by the classification boxes, matching the order of the
    list handle_invoke() returns.
    """
    # Derive the count from the registry so the header cannot go stale.
    model_count = len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS)

    with gr.Blocks() as demo:
        # Header
        gr.Markdown(
            f"# Toxic Eye ({model_count} Models Version)\n"
            f"This system evaluates the toxicity level of input text "
            f"using {model_count} local models."
        )

        # Input section
        with gr.Row():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to analyze...",
                lines=3,
            )

        # Run button
        with gr.Row():
            invoke_button = gr.Button(
                "Analyze Text",
                variant="primary",
                size="lg",
            )

        # Per-model output areas, one tab per model category
        with gr.Tabs():
            with gr.Tab("Text Generation Models"):
                gen_outputs = _build_output_boxes(TEXT_GENERATION_MODELS)
            with gr.Tab("Classification Models"):
                class_outputs = _build_output_boxes(CLASSIFICATION_MODELS)

        # Wire the button; output order must match handle_invoke()'s result list.
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text],
            outputs=gen_outputs + class_outputs,
        )

    return demo
def main():
    """Application entry point: preload all models, then build and launch the UI."""
    # Load the models before the UI starts accepting requests.
    preload_models()
    create_ui().launch()


if __name__ == "__main__":
    main()