import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from huggingface_hub import InferenceClient
import logging
import spaces  # Hugging Face Spaces helper; kept for Spaces compatibility (no @spaces.GPU decorators are used here)

# Logger configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Model definitions (both local models and Inference API models)
TEXT_GENERATION_MODELS = [
    {
        "name": "Llama-2",
        "description": "Known for its robust performance in content analysis",
        "type": "local",
        "model_path": "meta-llama/Llama-2-7b-hf"
    },
    {
        "name": "Mistral-7B",
        "description": "Offers precise and detailed text evaluation",
        "type": "local",
        "model_path": "mistralai/Mistral-7B-v0.1"
    },
    {
        "name": "Zephyr-7B",
        "description": "Specialized in understanding context and nuance",
        "type": "api",
        "model_id": "HuggingFaceH4/zephyr-7b-beta"
    }
]

CLASSIFICATION_MODELS = [
    {
        "name": "Toxic-BERT",
        "description": "Fine-tuned for toxic content detection",
        "type": "local",
        "model_path": "unitary/toxic-bert"
    }
]

# Global registries for models and API clients
tokenizers = {}
pipelines = {}
api_clients = {}
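# The dictionaries above act as simple caches keyed by model path or model ID:
# each tokenizer, pipeline, and API client is created once at startup and then
# reused for every request.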
def initialize_api_clients():
    """Initialize Inference API clients."""
    for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
        if model["type"] == "api" and "model_id" in model:
            logger.info(f"Initializing API client for {model['name']}")
            api_clients[model["model_id"]] = InferenceClient(
                model["model_id"],
                token=True  # use the locally saved HF token
            )
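# Note: token=True tells huggingface_hub to use the token saved locally by
# `huggingface-cli login` (or the HF_TOKEN environment variable). Gated
# repositories such as meta-llama/Llama-2-7b-hf additionally require that the
# account behind that token has been granted access to the model.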
def preload_local_models():
    """Preload local models at application startup."""
    logger.info("Preloading local models at application startup...")

    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] == "local" and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading text generation model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-generation",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map="auto"
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] == "local" and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading classification model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-classification",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map="auto"
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")
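# Note: device_map="auto" requires the `accelerate` package to be installed.
# As a rough estimate, a 7B-parameter model in bfloat16 needs about 14 GB of
# memory (7e9 params x 2 bytes), so preloading both local generation models
# assumes on the order of 28 GB of GPU/CPU memory is available.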
def generate_text_local(model_path, text):
    """Text generation with a local model."""
    try:
        logger.info(f"Running local text generation with {model_path}")
        outputs = pipelines[model_path](
            text,
            max_new_tokens=100,
            do_sample=False,
            num_return_sequences=1
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error in local text generation with {model_path}: {str(e)}")
        return f"Error: {str(e)}"
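# Note: for the text-generation pipeline, outputs[0]["generated_text"] contains
# the prompt followed by the continuation. If only the newly generated text is
# wanted, the call above could pass return_full_text=False, e.g.:
#     pipelines[model_path](text, max_new_tokens=100, return_full_text=False)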
def generate_text_api(model_id, text):
    """Text generation via the Inference API."""
    try:
        logger.info(f"Running API text generation with {model_id}")
        response = api_clients[model_id].text_generation(
            text,
            max_new_tokens=100,
            temperature=0.7
        )
        return response
    except Exception as e:
        logger.error(f"Error in API text generation with {model_id}: {str(e)}")
        return f"Error: {str(e)}"
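# Note: InferenceClient.text_generation returns the generated string by
# default. Passing details=True returns a richer object with token-level
# metadata, and stream=True yields the output incrementally.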
def classify_text_local(model_path, text):
    """Text classification with a local model."""
    try:
        logger.info(f"Running local classification with {model_path}")
        result = pipelines[model_path](text)
        return str(result)
    except Exception as e:
        logger.error(f"Error in local classification with {model_path}: {str(e)}")
        return f"Error: {str(e)}"
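# Note: the text-classification pipeline returns a list of dicts such as
# [{"label": "toxic", "score": 0.98}] (the exact label names depend on the
# model). str(result) is a quick way to surface this in the UI; a formatted
# summary built from the label/score pairs would read better.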
def classify_text_api(model_id, text):
    """Text classification via the Inference API."""
    try:
        logger.info(f"Running API classification with {model_id}")
        response = api_clients[model_id].text_classification(text)
        return str(response)
    except Exception as e:
        logger.error(f"Error in API classification with {model_id}: {str(e)}")
        return f"Error: {str(e)}"
def handle_invoke(text, selected_types):
    """Run analysis with every model whose type is selected."""
    results = []

    # Iterate in the same order the output textboxes are created in create_ui
    # (text generation models first, then classification models). Models whose
    # type is filtered out contribute an empty string so that results[i] always
    # lands in the textbox belonging to the same model.

    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] not in selected_types:
            results.append("")
            continue
        if model["type"] == "local":
            result = generate_text_local(model["model_path"], text)
        else:  # api
            result = generate_text_api(model["model_id"], text)
        results.append(f"{model['name']}: {result}")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] not in selected_types:
            results.append("")
            continue
        if model["type"] == "local":
            result = classify_text_local(model["model_path"], text)
        else:  # api
            result = classify_text_api(model["model_id"], text)
        results.append(f"{model['name']}: {result}")

    return results
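# Contract with the UI: results[i] must correspond to all_outputs[i] in
# create_ui below. Both follow the same order (TEXT_GENERATION_MODELS first,
# then CLASSIFICATION_MODELS), and filtered-out models contribute an empty
# string so the positions never shift.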
def create_ui():
    """Build the UI."""
    with gr.Blocks() as demo:
        # Header
        gr.Markdown("""
        # Toxic Eye (Local + API Version)
        This system evaluates the toxicity level of input text using both local models and the Inference API.
        """)

        # Input section
        with gr.Row():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to analyze...",
                lines=3
            )

        # Filter section
        with gr.Row():
            filter_checkboxes = gr.CheckboxGroup(
                choices=["local", "api"],
                value=["local", "api"],
                label="Filter Models",
                info="Choose which types of models to use",
                interactive=True
            )

        # Run button
        with gr.Row():
            invoke_button = gr.Button(
                "Analyze Text",
                variant="primary",
                size="lg"
            )

        # Model output display area
        all_outputs = []
        with gr.Tabs():
            # Text generation model tab
            with gr.Tab("Text Generation Models"):
                for model in TEXT_GENERATION_MODELS:
                    with gr.Group():
                        gr.Markdown(f"### {model['name']} ({model['type']})")
                        output = gr.Textbox(
                            label=f"{model['name']} Output",
                            lines=5,
                            interactive=False,
                            info=model["description"]
                        )
                        all_outputs.append(output)

            # Classification model tab
            with gr.Tab("Classification Models"):
                for model in CLASSIFICATION_MODELS:
                    with gr.Group():
                        gr.Markdown(f"### {model['name']} ({model['type']})")
                        output = gr.Textbox(
                            label=f"{model['name']} Output",
                            lines=5,
                            interactive=False,
                            info=model["description"]
                        )
                        all_outputs.append(output)

        # Wire up events
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text, filter_checkboxes],
            outputs=all_outputs
        )

    return demo
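# Gradio maps the list returned by handle_invoke positionally onto the
# components listed in `outputs`, which is why the textboxes are collected
# into all_outputs in creation order.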
def main():
    # Initialize API clients
    initialize_api_clients()
    # Preload local models
    preload_local_models()
    # Build the UI and launch
    demo = create_ui()
    demo.launch()

if __name__ == "__main__":
    main()
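# When running outside Hugging Face Spaces (e.g. in a container), the launch
# call may need explicit binding such as
# demo.launch(server_name="0.0.0.0", server_port=7860); on Spaces the default
# launch() is sufficient.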