import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from huggingface_hub import InferenceClient
import logging
import spaces

# Logger configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Model definitions (both local and API models)
TEXT_GENERATION_MODELS = [
    {
        "name": "Llama-2",
        "description": "Known for its robust performance in content analysis",
        "type": "local",
        "model_path": "meta-llama/Llama-2-7b-hf"
    },
    {
        "name": "Mistral-7B",
        "description": "Offers precise and detailed text evaluation",
        "type": "local",
        "model_path": "mistralai/Mistral-7B-v0.1"
    },
    {
        "name": "Zephyr-7B",
        "description": "Specialized in understanding context and nuance",
        "type": "api",
        "model_id": "HuggingFaceH4/zephyr-7b-beta"
    }
]

CLASSIFICATION_MODELS = [
    {
        "name": "Toxic-BERT",
        "description": "Fine-tuned for toxic content detection",
        "type": "local",
        "model_path": "unitary/toxic-bert"
    }
]

# Models and API clients are kept in module-level globals
tokenizers = {}
pipelines = {}
api_clients = {}


def initialize_api_clients():
    """Initialize Inference API clients."""
    for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
        if model["type"] == "api" and "model_id" in model:
            logger.info(f"Initializing API client for {model['name']}")
            api_clients[model["model_id"]] = InferenceClient(
                model["model_id"],
                token=True  # use the locally saved HF token
            )


def preload_local_models():
    """Preload local models at application startup."""
    logger.info("Preloading local models at application startup...")

    # Text generation models
    for model in TEXT_GENERATION_MODELS:
        if model["type"] == "local" and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading text generation model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-generation",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map="auto"
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] == "local" and "model_path" in model:
            model_path = model["model_path"]
            try:
                logger.info(f"Preloading classification model: {model_path}")
                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                pipelines[model_path] = pipeline(
                    "text-classification",
                    model=model_path,
                    tokenizer=tokenizers[model_path],
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    device_map="auto"
                )
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                logger.error(f"Error preloading model {model_path}: {str(e)}")


@spaces.GPU
def generate_text_local(model_path, text):
    """Text generation with a local model."""
    try:
        logger.info(f"Running local text generation with {model_path}")
        outputs = pipelines[model_path](
            text,
            max_new_tokens=100,
            do_sample=False,
            num_return_sequences=1
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error in local text generation with {model_path}: {str(e)}")
        return f"Error: {str(e)}"


def generate_text_api(model_id, text):
    """Text generation via the Inference API."""
    try:
        logger.info(f"Running API text generation with {model_id}")
        response = api_clients[model_id].text_generation(
            text,
            max_new_tokens=100,
            temperature=0.7
        )
        return response
    except Exception as e:
        logger.error(f"Error in API text generation with {model_id}: {str(e)}")
        return f"Error: {str(e)}"
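
# --- Illustrative addition (not in the original app): retry wrapper ---
# The Inference API can fail transiently (model cold starts, rate limits),
# so a caller may want to retry. This is a minimal sketch with exponential
# backoff; the helper name, retry count, and delays are arbitrary choices.
def generate_text_api_with_retry(model_id, text, retries=3, base_delay=2.0):
    """Call generate_text_api, retrying failed calls with exponential backoff."""
    import time

    result = "Error: no attempts made"
    for attempt in range(1, retries + 1):
        result = generate_text_api(model_id, text)
        # generate_text_api reports failures as an "Error: ..." string
        if not result.startswith("Error:"):
            return result
        if attempt < retries:
            wait = base_delay * (2 ** (attempt - 1))
            logger.warning(
                f"API call failed (attempt {attempt}/{retries}); retrying in {wait:.1f}s"
            )
            time.sleep(wait)
    return result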
@spaces.GPU
def classify_text_local(model_path, text):
    """Text classification with a local model."""
    try:
        logger.info(f"Running local classification with {model_path}")
        result = pipelines[model_path](text)
        return str(result)
    except Exception as e:
        logger.error(f"Error in local classification with {model_path}: {str(e)}")
        return f"Error: {str(e)}"


def classify_text_api(model_id, text):
    """Text classification via the Inference API."""
    try:
        logger.info(f"Running API classification with {model_id}")
        response = api_clients[model_id].text_classification(text)
        return str(response)
    except Exception as e:
        logger.error(f"Error in API classification with {model_id}: {str(e)}")
        return f"Error: {str(e)}"


def handle_invoke(text, selected_types):
    """Run the analysis with the models of the selected types."""
    results = []

    # Text generation models. Keep one result slot per model (empty when the
    # model's type is filtered out) so results stay aligned with the output
    # textboxes, which are wired up positionally.
    for model in TEXT_GENERATION_MODELS:
        if model["type"] in selected_types:
            if model["type"] == "local":
                result = generate_text_local(model["model_path"], text)
            else:  # api
                result = generate_text_api(model["model_id"], text)
            results.append(f"{model['name']}: {result}")
        else:
            results.append("")

    # Classification models
    for model in CLASSIFICATION_MODELS:
        if model["type"] in selected_types:
            if model["type"] == "local":
                result = classify_text_local(model["model_path"], text)
            else:  # api
                result = classify_text_api(model["model_id"], text)
            results.append(f"{model['name']}: {result}")
        else:
            results.append("")

    return results


def create_ui():
    """Build the UI."""
    with gr.Blocks() as demo:
        # Header
        gr.Markdown("""
        # Toxic Eye (Local + API Version)
        This system evaluates the toxicity level of input text using both local models and the Inference API.
        """)

        # Input section
        with gr.Row():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to analyze...",
                lines=3
            )

        # Filter section
        with gr.Row():
            filter_checkboxes = gr.CheckboxGroup(
                choices=["local", "api"],
                value=["local", "api"],
                label="Filter Models",
                info="Choose which types of models to use",
                interactive=True
            )

        # Run button
        with gr.Row():
            invoke_button = gr.Button(
                "Analyze Text",
                variant="primary",
                size="lg"
            )

        # Model output area
        all_outputs = []
        with gr.Tabs():
            # Text generation model tab
            with gr.Tab("Text Generation Models"):
                for model in TEXT_GENERATION_MODELS:
                    with gr.Group():
                        gr.Markdown(f"### {model['name']} ({model['type']})")
                        output = gr.Textbox(
                            label=f"{model['name']} Output",
                            lines=5,
                            interactive=False,
                            info=model["description"]
                        )
                        all_outputs.append(output)

            # Classification model tab
            with gr.Tab("Classification Models"):
                for model in CLASSIFICATION_MODELS:
                    with gr.Group():
                        gr.Markdown(f"### {model['name']} ({model['type']})")
                        output = gr.Textbox(
                            label=f"{model['name']} Output",
                            lines=5,
                            interactive=False,
                            info=model["description"]
                        )
                        all_outputs.append(output)

        # Wire up events
        invoke_button.click(
            fn=handle_invoke,
            inputs=[input_text, filter_checkboxes],
            outputs=all_outputs
        )

    return demo


def main():
    # Initialize API clients
    initialize_api_clients()

    # Preload local models
    preload_local_models()

    # Build and launch the UI
    demo = create_ui()
    demo.launch()


if __name__ == "__main__":
    main()
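
# Deployment notes (assumptions, not stated in the original app):
# - meta-llama/Llama-2-7b-hf is a gated model, and InferenceClient is created
#   with token=True, so a valid Hugging Face token must be available, e.g. via
#   `huggingface-cli login` or an HF_TOKEN environment variable / Space secret.
# - On Hugging Face Spaces it is common to enable the request queue so that
#   concurrent users do not contend for the GPU: demo.queue().launch()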