# NOTE: The lines below were Hugging Face Spaces page chrome ("Spaces:" /
# "Sleeping" status badges) captured by the scrape; commented out so the
# file parses as Python.
# Spaces:
# Sleeping
# Sleeping
import asyncio
import logging
from functools import partial

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, pipeline
# Logging configuration (timestamped, module-scoped logger).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model identifiers.
classification_model_name = "unitary/toxic-bert"
generation_model_name = "distilgpt2"  # lightweight text-generation model

logger.info("Starting model loading...")

# Load the classification model (toxicity classifier).
logger.info("Loading classification model: %s", classification_model_name)
classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_name)
classification_pipeline = pipeline(
    "text-classification",
    model=classification_model_name,
    tokenizer=classification_tokenizer,
    torch_dtype=torch.bfloat16,
    # SECURITY NOTE(review): trust_remote_code=True allows arbitrary code from
    # the model repo to run; unitary/toxic-bert is a standard architecture and
    # likely does not need it — confirm and consider removing.
    trust_remote_code=True,
    device_map="auto",
)
logger.info("Classification model loaded successfully: %s", classification_model_name)

# Load the text-generation model.
logger.info("Loading generation model: %s", generation_model_name)
generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
generation_pipeline = pipeline(
    "text-generation",
    model=generation_model_name,
    tokenizer=generation_tokenizer,
    torch_dtype=torch.bfloat16,
    # SECURITY NOTE(review): same trust_remote_code concern as above —
    # distilgpt2 is a stock GPT-2 architecture.
    trust_remote_code=True,
    device_map="auto",
)
logger.info("Generation model loaded successfully: %s", generation_model_name)
async def classify_text_async(prompt: str):
    """Run the toxicity-classification pipeline without blocking the event loop.

    Args:
        prompt: Input text to classify.

    Returns:
        The raw pipeline output (a list of ``{"label": ..., "score": ...}``
        dicts from the text-classification pipeline).
    """
    logger.info("Running classification for: %s...", prompt[:50])
    # The pipeline call is synchronous and compute-bound, so run it in the
    # default executor to keep the event loop responsive.
    # get_running_loop() is the correct call inside a coroutine
    # (get_event_loop() is deprecated in this context since Python 3.10).
    loop = asyncio.get_running_loop()
    classification_result = await loop.run_in_executor(
        None,
        partial(classification_pipeline, prompt),
    )
    logger.info("Classification complete: %s", classification_result)
    return classification_result
async def generate_text_async(prompt: str) -> str:
    """Run the text-generation pipeline without blocking the event loop.

    Args:
        prompt: Prompt text to continue.

    Returns:
        The generated text (includes the prompt, as produced by the
        text-generation pipeline's ``generated_text`` field).
    """
    logger.info("Running text generation for: %s...", prompt[:50])
    # get_running_loop() is the correct call inside a coroutine
    # (get_event_loop() is deprecated in this context since Python 3.10).
    loop = asyncio.get_running_loop()
    # Offload the synchronous, compute-bound pipeline call to the default
    # executor so the event loop stays responsive.
    generation_result = await loop.run_in_executor(
        None,
        partial(
            generation_pipeline,
            prompt,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            num_return_sequences=1,
        ),
    )
    generated_text = generation_result[0]["generated_text"]
    logger.info("Text generation complete, generated: %s chars", len(generated_text))
    return generated_text
async def process_text_async(prompt: str) -> str:
    """Run classification and generation concurrently and combine the results.

    Args:
        prompt: Input text fed to both pipelines.

    Returns:
        A single display string with the classification result and the
        generated text (Japanese labels are part of the UI output).
    """
    logger.info("Processing input asynchronously: %s...", prompt[:50])
    # asyncio.gather schedules both coroutines concurrently and waits for both;
    # results come back in argument order.
    classification_result, generated_text = await asyncio.gather(
        classify_text_async(prompt),
        generate_text_async(prompt),
    )
    combined_result = (
        f"分類結果: {classification_result}\n\n生成されたテキスト: {generated_text}"
    )
    return combined_result
# Gradio accepts async callables directly, so the coroutine function is
# passed as-is; Gradio awaits it per request.
demo = gr.Interface(
    fn=process_text_async,
    inputs=gr.Textbox(lines=3, label="入力テキスト"),
    outputs=gr.Textbox(label="処理結果", lines=8),
    title="テキスト分類 & 生成デモ (非同期版)",
    description="入力テキストに対して分類と生成を非同期で並行実行します。",
)

# Launch the app. Kept at module level (no __main__ guard) because Hugging
# Face Spaces executes this file as the app entry point.
logger.info("Starting application...")
demo.launch()