# test-zerogpu-2 / app.py
# Author: nyasukun — commit a929439 (verified), 3.79 kB
import spaces
import gradio as gr
from transformers import AutoTokenizer, pipeline
import torch
import logging
import asyncio
from functools import partial
# Logging configuration: timestamped records at INFO level for the whole app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Model identifiers.
classification_model_name = "unitary/toxic-bert"
generation_model_name = "distilgpt2"  # lightweight text-generation model

logger.info("Starting model loading...")

# Load the toxicity-classification pipeline (runs at import time).
logger.info(f"Loading classification model: {classification_model_name}")
classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_name)
classification_pipeline = pipeline(
    "text-classification",
    model=classification_model_name,
    tokenizer=classification_tokenizer,
    torch_dtype=torch.bfloat16,
    # NOTE(review): trust_remote_code=True lets Hub repos execute arbitrary
    # code at load time; unitary/toxic-bert is a stock BERT architecture,
    # so this flag looks unnecessary — confirm before removing.
    trust_remote_code=True,
    device_map="auto"
)
logger.info(f"Classification model loaded successfully: {classification_model_name}")
# Load the text-generation pipeline (runs at import time).
logger.info(f"Loading generation model: {generation_model_name}")
generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
generation_pipeline = pipeline(
    "text-generation",
    model=generation_model_name,
    tokenizer=generation_tokenizer,
    torch_dtype=torch.bfloat16,
    # NOTE(review): trust_remote_code=True is likely unnecessary for the
    # stock distilgpt2 checkpoint — confirm before removing.
    trust_remote_code=True,
    device_map="auto"
)
logger.info(f"Generation model loaded successfully: {generation_model_name}")
async def classify_text_async(prompt):
    """Run toxicity classification on *prompt* without blocking the event loop.

    The transformers pipeline call is synchronous, so it is off-loaded to the
    default thread-pool executor.

    Args:
        prompt: Input text to classify.

    Returns:
        The raw output of ``classification_pipeline(prompt)`` (label/score
        records).
    """
    logger.info(f"Running classification for: {prompt[:50]}...")
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() is deprecated here (Python 3.10+) and may create a
    # fresh loop instead of returning the one actually running.
    loop = asyncio.get_running_loop()
    classification_result = await loop.run_in_executor(
        None,
        partial(classification_pipeline, prompt),
    )
    logger.info(f"Classification complete: {classification_result}")
    return classification_result
async def generate_text_async(prompt):
    """Generate a sampled continuation of *prompt* without blocking the loop.

    The generation pipeline runs in the default thread-pool executor with
    sampling enabled (temperature 0.7, up to 50 new tokens).

    Args:
        prompt: Seed text for generation.

    Returns:
        The generated text (includes the prompt, per the pipeline's
        ``generated_text`` field).
    """
    logger.info(f"Running text generation for: {prompt[:50]}...")
    # Use get_running_loop() — get_event_loop() is deprecated inside
    # coroutines (Python 3.10+). partial() (already imported) replaces the
    # lambda for the executor callable.
    loop = asyncio.get_running_loop()
    generation_result = await loop.run_in_executor(
        None,
        partial(
            generation_pipeline,
            prompt,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            num_return_sequences=1,
        ),
    )
    generated_text = generation_result[0]["generated_text"]
    logger.info(f"Text generation complete, generated: {len(generated_text)} chars")
    return generated_text
# GPU-backed asynchronous inference entry point.
@spaces.GPU(duration=120)
async def process_text_async(prompt):
    """Classify and generate from *prompt* concurrently; return one string.

    Args:
        prompt: User-supplied input text.

    Returns:
        A combined display string with the classification result followed by
        the generated text.
    """
    logger.info(f"Processing input asynchronously: {prompt[:50]}...")
    # Launch both pipelines at once; gather() preserves argument order, so
    # the first result is the classification and the second the generation.
    classification_result, generated_text = await asyncio.gather(
        classify_text_async(prompt),
        generate_text_async(prompt),
    )
    # Merge both outputs into the single textbox payload Gradio displays.
    return f"分類結果: {classification_result}\n\n生成されたテキスト: {generated_text}"
# Gradio supports async handlers, so the coroutine function is passed as-is.
demo = gr.Interface(
    fn=process_text_async,  # async handler
    inputs=gr.Textbox(lines=3, label="入力テキスト"),
    outputs=gr.Textbox(label="処理結果", lines=8),
    title="テキスト分類 & 生成デモ (非同期版)",
    description="入力テキストに対して分類と生成を非同期で並行実行します。"
)

# Start the app (blocks until the server stops).
logger.info("Starting application...")
demo.launch()