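"""Toxic Eye (Local + API Version).

Gradio app that scores the toxicity of input text with a mix of locally
loaded transformers pipelines and hosted Inference API models.
"""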
import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from huggingface_hub import InferenceClient
import logging
import spaces
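# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU
# Spaces: a GPU is attached only while a decorated function runs.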
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Model definitions (both local and API models)
TEXT_GENERATION_MODELS = [
{
"name": "Llama-2",
"description": "Known for its robust performance in content analysis",
"type": "local",
"model_path": "meta-llama/Llama-2-7b-hf"
},
{
"name": "Mistral-7B",
"description": "Offers precise and detailed text evaluation",
"type": "local",
"model_path": "mistralai/Mistral-7B-v0.1"
},
{
"name": "Zephyr-7B",
"description": "Specialized in understanding context and nuance",
"type": "api",
"model_id": "HuggingFaceH4/zephyr-7b-beta"
}
]
CLASSIFICATION_MODELS = [
{
"name": "Toxic-BERT",
"description": "Fine-tuned for toxic content detection",
"type": "local",
"model_path": "unitary/toxic-bert"
}
]
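# Each entry declares its backend via "type": "local" entries carry a
# "model_path" loaded with a transformers pipeline, while "api" entries
# carry a "model_id" queried through InferenceClient.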
# Module-level caches for tokenizers, pipelines, and API clients
tokenizers = {}
pipelines = {}
api_clients = {}
def initialize_api_clients():
"""Inference APIクライアントの初期化"""
for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
if model["type"] == "api" and "model_id" in model:
logger.info(f"Initializing API client for {model['name']}")
api_clients[model["model_id"]] = InferenceClient(
model["model_id"],
                token=True  # use the locally saved HF token (HF_TOKEN env var or huggingface-cli login)
)
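# InferenceClient sends requests to the hosted Inference API over HTTP, so
# "api" models never consume local GPU memory; only "local" models are
# loaded into this process (see preload_local_models below).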
def preload_local_models():
"""ローカルモデルを事前ロード"""
logger.info("Preloading local models at application startup...")
    # Text generation models
for model in TEXT_GENERATION_MODELS:
if model["type"] == "local" and "model_path" in model:
model_path = model["model_path"]
try:
logger.info(f"Preloading text generation model: {model_path}")
tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
pipelines[model_path] = pipeline(
"text-generation",
model=model_path,
tokenizer=tokenizers[model_path],
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto"
)
logger.info(f"Model preloaded successfully: {model_path}")
except Exception as e:
logger.error(f"Error preloading model {model_path}: {str(e)}")
    # Classification models
for model in CLASSIFICATION_MODELS:
if model["type"] == "local" and "model_path" in model:
model_path = model["model_path"]
try:
logger.info(f"Preloading classification model: {model_path}")
tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
pipelines[model_path] = pipeline(
"text-classification",
model=model_path,
tokenizer=tokenizers[model_path],
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto"
)
logger.info(f"Model preloaded successfully: {model_path}")
except Exception as e:
logger.error(f"Error preloading model {model_path}: {str(e)}")
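# A more compact variant of the two loops above (a sketch only; `_preload`
# is a hypothetical helper, not used elsewhere in this app):
#
#     def _preload(task, model_path):
#         tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
#         pipelines[model_path] = pipeline(
#             task,
#             model=model_path,
#             tokenizer=tokenizers[model_path],
#             torch_dtype=torch.bfloat16,
#             trust_remote_code=True,
#             device_map="auto",
#         )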
@spaces.GPU
def generate_text_local(model_path, text):
"""ローカルモデルでのテキスト生成"""
try:
logger.info(f"Running local text generation with {model_path}")
outputs = pipelines[model_path](
text,
max_new_tokens=100,
do_sample=False,
num_return_sequences=1
)
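        # "generated_text" includes the prompt followed by the continuation
        # (the pipeline's return_full_text defaults to True)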
return outputs[0]["generated_text"]
except Exception as e:
logger.error(f"Error in local text generation with {model_path}: {str(e)}")
return f"Error: {str(e)}"
def generate_text_api(model_id, text):
"""API経由でのテキスト生成"""
try:
logger.info(f"Running API text generation with {model_id}")
response = api_clients[model_id].text_generation(
text,
max_new_tokens=100,
temperature=0.7
)
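        # With the default details=False, text_generation returns the
        # generated string directly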
return response
except Exception as e:
logger.error(f"Error in API text generation with {model_id}: {str(e)}")
return f"Error: {str(e)}"
@spaces.GPU
def classify_text_local(model_path, text):
"""ローカルモデルでのテキスト分類"""
try:
logger.info(f"Running local classification with {model_path}")
result = pipelines[model_path](text)
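        # The text-classification pipeline returns a list of
        # {"label": ..., "score": ...} dicts; stringify for display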
return str(result)
except Exception as e:
logger.error(f"Error in local classification with {model_path}: {str(e)}")
return f"Error: {str(e)}"
def classify_text_api(model_id, text):
"""API経由でのテキスト分類"""
try:
logger.info(f"Running API classification with {model_id}")
response = api_clients[model_id].text_classification(text)
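        # The API returns a list of label/score elements, one per class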
return str(response)
except Exception as e:
logger.error(f"Error in API classification with {model_id}: {str(e)}")
return f"Error: {str(e)}"
def handle_invoke(text, selected_types):
    """Run analysis with every model whose type is selected."""
    results = []
    # Keep one result slot per model, in the same order as the output
    # textboxes, so results stay aligned even when a type is deselected.
    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] in selected_types:
            if model["type"] == "local":
                result = generate_text_local(model["model_path"], text)
            else:  # api
                result = generate_text_api(model["model_id"], text)
            results.append(f"{model['name']}: {result}")
        else:
            results.append("")
    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] in selected_types:
            if model["type"] == "local":
                result = classify_text_local(model["model_path"], text)
            else:  # api
                result = classify_text_api(model["model_id"], text)
            results.append(f"{model['name']}: {result}")
        else:
            results.append("")
    return results
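# Example: with selected_types == ["api"], both local generation slots stay
# "" and only Zephyr-7B's textbox is filled, so outputs never shift boxes.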
def create_ui():
"""UIの作成"""
with gr.Blocks() as demo:
        # Header
gr.Markdown("""
# Toxic Eye (Local + API Version)
        This system evaluates the toxicity level of input text using both local models and the Inference API.
""")
        # Input section
with gr.Row():
input_text = gr.Textbox(
label="Input Text",
placeholder="Enter text to analyze...",
lines=3
)
        # Filter section
with gr.Row():
filter_checkboxes = gr.CheckboxGroup(
choices=["local", "api"],
value=["local", "api"],
label="Filter Models",
info="Choose which types of models to use",
interactive=True
)
        # Run button
with gr.Row():
invoke_button = gr.Button(
"Analyze Text",
variant="primary",
size="lg"
)
        # Model output area
all_outputs = []
with gr.Tabs():
            # Text generation models tab
with gr.Tab("Text Generation Models"):
for model in TEXT_GENERATION_MODELS:
with gr.Group():
gr.Markdown(f"### {model['name']} ({model['type']})")
output = gr.Textbox(
label=f"{model['name']} Output",
lines=5,
interactive=False,
info=model["description"]
)
all_outputs.append(output)
            # Classification models tab
with gr.Tab("Classification Models"):
for model in CLASSIFICATION_MODELS:
with gr.Group():
gr.Markdown(f"### {model['name']} ({model['type']})")
output = gr.Textbox(
label=f"{model['name']} Output",
lines=5,
interactive=False,
info=model["description"]
)
all_outputs.append(output)
        # Event wiring
invoke_button.click(
fn=handle_invoke,
inputs=[input_text, filter_checkboxes],
outputs=all_outputs
)
return demo
def main():
    # Initialize API clients
    initialize_api_clients()
    # Preload local models
    preload_local_models()
    # Build and launch the UI
    demo = create_ui()
    demo.launch()
if __name__ == "__main__":
main()
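# On Hugging Face Spaces this file is executed directly and demo.launch()
# serves the app; running `python app.py` locally starts it on Gradio's
# default port (7860).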