# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, pipeline | |
import logging | |
import spaces | |
# Logger setup: INFO level, each record formatted as
# "<time> - <logger name> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Model registry: two text-generation models plus one classification model,
# i.e. three models in total across both lists.
# Each entry holds a display "name", a UI "description", and the "model_path"
# identifier passed to from_pretrained()/pipeline(); model_path is also the
# key used in the `tokenizers` / `pipelines` caches below.
TEXT_GENERATION_MODELS = [
    dict(
        name="Llama-2",
        description="Known for its robust performance in content analysis",
        model_path="meta-llama/Llama-2-7b-hf",
    ),
    dict(
        name="Mistral-7B",
        description="Offers precise and detailed text evaluation",
        model_path="mistralai/Mistral-7B-v0.1",
    ),
]

CLASSIFICATION_MODELS = [
    dict(
        name="Toxic-BERT",
        description="Fine-tuned for toxic content detection",
        model_path="unitary/toxic-bert",
    ),
]

# Module-level caches filled by preload_models(), keyed by model_path.
tokenizers = dict()
pipelines = dict()
def _preload_pipeline(task, label, model_path):
    """Load a tokenizer and a ``task`` pipeline for ``model_path`` into the global caches.

    Args:
        task: transformers pipeline task name (e.g. "text-generation").
        label: Human-readable task name used only in log messages.
        model_path: Model identifier; also the cache key.

    Returns:
        True if both objects were loaded and cached, False otherwise. Failures
        are logged with a traceback and swallowed so one unavailable model does
        not stop the remaining models (or the UI) from starting.
    """
    try:
        logger.info("Preloading %s model: %s", label, model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        pipe = pipeline(
            task,
            model=model_path,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            # NOTE(review): the configured repos are presumably supported
            # natively by transformers, so remote code should not be needed;
            # consider dropping this to avoid executing repo-supplied code.
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        logger.exception("Error preloading model %s: %s", model_path, e)
        return False
    # Commit both only on success so the caches never hold a tokenizer
    # without its pipeline.
    tokenizers[model_path] = tokenizer
    pipelines[model_path] = pipe
    logger.info("Model preloaded successfully: %s", model_path)
    return True


def preload_models():
    """Preload every configured model at application startup.

    Text-generation models are loaded first, then classification models, each
    into the module-level ``tokenizers`` / ``pipelines`` dicts keyed by
    model_path. A model that fails to load is logged and skipped; a summary
    warning lists every model that will be unavailable.
    """
    logger.info("Preloading models at application startup...")
    failed = []
    for model in TEXT_GENERATION_MODELS:
        if not _preload_pipeline("text-generation", "text generation", model["model_path"]):
            failed.append(model["model_path"])
    for model in CLASSIFICATION_MODELS:
        if not _preload_pipeline("text-classification", "classification", model["model_path"]):
            failed.append(model["model_path"])
    if failed:
        logger.warning(
            "%d model(s) failed to preload and will be unavailable: %s",
            len(failed),
            ", ".join(failed),
        )
def generate_text(model_path, text):
    """Run greedy text generation on ``text`` with the preloaded pipeline for ``model_path``.

    Args:
        model_path: Key into the module-level ``pipelines`` cache (a model_path
            from TEXT_GENERATION_MODELS).
        text: Prompt string.

    Returns:
        The ``generated_text`` of the first (only) returned sequence, or an
        ``"Error: ..."`` string if the model is not loaded or generation fails.
        Errors are returned rather than raised so the caller can display them
        in the UI output box.
    """
    pipe = pipelines.get(model_path)
    if pipe is None:
        # preload_models() logs and skips models that fail to load; without
        # this check the user would only see the bare KeyError text
        # "Error: '<model_path>'".
        logger.error("Text generation model not loaded: %s", model_path)
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info("Running text generation with %s", model_path)
        outputs = pipe(
            text,
            max_new_tokens=100,
            do_sample=False,  # greedy decoding: deterministic output
            num_return_sequences=1,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {str(e)}"
def classify_text(model_path, text):
    """Classify ``text`` with the preloaded pipeline for ``model_path``.

    Args:
        model_path: Key into the module-level ``pipelines`` cache (a model_path
            from CLASSIFICATION_MODELS).
        text: Input string to classify.

    Returns:
        ``str()`` of the pipeline result, or an ``"Error: ..."`` string if the
        model is not loaded or classification fails. Errors are returned rather
        than raised so the caller can display them in the UI output box.
    """
    pipe = pipelines.get(model_path)
    if pipe is None:
        # preload_models() logs and skips models that fail to load; without
        # this check the user would only see the bare KeyError text
        # "Error: '<model_path>'".
        logger.error("Classification model not loaded: %s", model_path)
        return f"Error: model {model_path} is not loaded"
    try:
        logger.info("Running classification with %s", model_path)
        return str(pipe(text))
    except Exception as e:
        logger.exception("Error in classification with %s: %s", model_path, e)
        return f"Error: {str(e)}"
def handle_invoke(text):
    """Analyze ``text`` with every configured model.

    Returns a list with one string per model: all text-generation results
    first, then all classification results, each group in registry order.
    This matches the ``gen_outputs + class_outputs`` textbox order wired up in
    create_ui().
    """
    # NOTE(review): `spaces` is imported at the top of the file but never used;
    # if this Space runs on ZeroGPU hardware, GPU work presumably needs a
    # `@spaces.GPU` decorator on this entry point — confirm.
    generated = [generate_text(entry["model_path"], text) for entry in TEXT_GENERATION_MODELS]
    classified = [classify_text(entry["model_path"], text) for entry in CLASSIFICATION_MODELS]
    return generated + classified
def create_ui():
    """Build and return the Gradio Blocks UI (the caller launches it).

    Layout: a header, an input textbox, an "Analyze Text" button, and two tabs
    containing one read-only output textbox per model. Clicking the button
    calls handle_invoke; its returned list is mapped onto the output textboxes
    in order, with text-generation models first and then classification models.
    """
    # Derived rather than hard-coded so the header cannot drift from the
    # registries (for the current lists this still renders "3").
    model_count = len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS)

    def _model_outputs(models):
        """Render one titled, read-only output textbox per model; return them in order."""
        boxes = []
        for model in models:
            with gr.Group():
                gr.Markdown(f"### {model['name']}")
                boxes.append(
                    gr.Textbox(
                        label=f"{model['name']} Output",
                        lines=5,
                        interactive=False,
                        info=model["description"],
                    )
                )
        return boxes

    with gr.Blocks() as demo:
        # Header
        gr.Markdown(
            f"\n# Toxic Eye ({model_count} Models Version)\n"
            f"This system evaluates the toxicity level of input text using {model_count} local models.\n"
        )

        # Input section
        with gr.Row():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to analyze...",
                lines=3,
            )

        # Run button
        with gr.Row():
            invoke_button = gr.Button(
                "Analyze Text",
                variant="primary",
                size="lg",
            )

        # Per-model output areas
        with gr.Tabs():
            with gr.Tab("Text Generation Models"):
                gen_outputs = _model_outputs(TEXT_GENERATION_MODELS)
            with gr.Tab("Classification Models"):
                class_outputs = _model_outputs(CLASSIFICATION_MODELS)

        # The output order must match the list order returned by handle_invoke.
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text],
            outputs=gen_outputs + class_outputs,
        )
    return demo
def main():
    """Entry point: preload every model, then build and launch the Gradio UI."""
    # Models are loaded before the UI is constructed.
    preload_models()
    create_ui().launch()


if __name__ == "__main__":
    main()