nyasukun committed
Commit 1e17339 (verified) · 1 Parent(s): 970a4b6

Update app.py

Files changed (1):
  1. app.py +118 -97
app.py CHANGED
@@ -1,112 +1,133 @@
-import spaces
 import gradio as gr
-from transformers import AutoTokenizer, pipeline
 import torch
 import logging
-import asyncio
-from functools import partial

-# Logging configuration
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

-# Model definitions
-classification_model_name = "unitary/toxic-bert"
-generation_model_name = "distilgpt2"  # lightweight text generation model

-logger.info("Starting model loading...")
-
-# Load the classification model
-logger.info(f"Loading classification model: {classification_model_name}")
-classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_name)
-classification_pipeline = pipeline(
-    "text-classification",
-    model=classification_model_name,
-    tokenizer=classification_tokenizer,
-    torch_dtype=torch.bfloat16,
-    trust_remote_code=True,
-    device_map="auto"
-)
-logger.info(f"Classification model loaded successfully: {classification_model_name}")
-
-# Load the generation model
-logger.info(f"Loading generation model: {generation_model_name}")
-generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
-generation_pipeline = pipeline(
-    "text-generation",
-    model=generation_model_name,
-    tokenizer=generation_tokenizer,
-    torch_dtype=torch.bfloat16,
-    trust_remote_code=True,
-    device_map="auto"
-)
-logger.info(f"Generation model loaded successfully: {generation_model_name}")
-
-# Synchronous GPU function (classification)
-@spaces.GPU(duration=60)
-def classify_text(prompt):
-    logger.info(f"Running classification for: {prompt[:50]}...")
-    classification_result = classification_pipeline(prompt)
-    logger.info(f"Classification complete: {classification_result}")
-    return classification_result

-# Synchronous GPU function (generation)
-@spaces.GPU(duration=60)
-def generate_text(prompt):
-    logger.info(f"Running text generation for: {prompt[:50]}...")
-    generation_result = generation_pipeline(
-        prompt,
-        max_new_tokens=50,
-        do_sample=True,
-        temperature=0.7,
-        num_return_sequences=1
     )
-    generated_text = generation_result[0]["generated_text"]
-    logger.info(f"Text generation complete, generated: {len(generated_text)} chars")
-    return generated_text

-# Async wrapper functions
-async def classify_text_async(prompt):
-    loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(None, lambda: classify_text(prompt))

-async def generate_text_async(prompt):
-    loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(None, lambda: generate_text(prompt))

-# Async function for the main processing
-async def process_text_async(prompt):
-    logger.info(f"Processing input asynchronously: {prompt[:50]}...")
-
-    # Run both tasks concurrently
-    classification_task = classify_text_async(prompt)
-    generation_task = generate_text_async(prompt)
-
-    # Wait for both tasks to complete
-    classification_result, generated_text = await asyncio.gather(
-        classification_task,
-        generation_task
-    )
-
-    # Combine and return the results
-    combined_result = f"Classification result: {classification_result}\n\nGenerated text: {generated_text}"
-    return combined_result

-# Synchronous wrapper for Gradio
-def process_text(prompt):
-    # Run the async function to completion
-    loop = asyncio.get_event_loop()
-    return loop.run_until_complete(process_text_async(prompt))

-# Gradio interface
-demo = gr.Interface(
-    fn=process_text,  # use the synchronous wrapper
-    inputs=gr.Textbox(lines=3, label="Input Text"),
-    outputs=gr.Textbox(label="Result", lines=8),
-    title="Text Classification & Generation Demo (Async Version)",
-    description="Runs classification and generation concurrently on the input text."
-)

-# Launch the app
-logger.info("Starting application...")
-demo.launch()
 import gradio as gr
+from huggingface_hub import InferenceClient
 import torch
+from transformers import AutoTokenizer, pipeline
 import logging
+import spaces

+# Logger configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
 logger = logging.getLogger(__name__)

+# Model definitions (just the two, kept simple)
+TEXT_GENERATION_MODEL = "mistralai/Mistral-7B-v0.1"
+CLASSIFICATION_MODEL = "unitary/toxic-bert"

+# Preload the models
+def load_text_generation_model():
+    logger.info(f"Loading text generation model: {TEXT_GENERATION_MODEL}")
+    tokenizer = AutoTokenizer.from_pretrained(TEXT_GENERATION_MODEL)
+    text_pipeline = pipeline(
+        "text-generation",
+        model=TEXT_GENERATION_MODEL,
+        tokenizer=tokenizer,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+        device_map="auto"
+    )
+    return text_pipeline

+def load_classification_model():
+    logger.info(f"Loading classification model: {CLASSIFICATION_MODEL}")
+    tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL)
+    class_pipeline = pipeline(
+        "text-classification",
+        model=CLASSIFICATION_MODEL,
+        tokenizer=tokenizer,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+        device_map="auto"
     )
+    return class_pipeline

+# Load the models
+text_gen_pipeline = load_text_generation_model()
+classification_pipeline = load_classification_model()

+# Text generation function
+@spaces.GPU
+def generate_text(text):
+    try:
+        logger.info("Running text generation")
+        outputs = text_gen_pipeline(
+            text,
+            max_new_tokens=100,
+            do_sample=False,
+            num_return_sequences=1
+        )
+        return outputs[0]["generated_text"]
+    except Exception as e:
+        logger.error(f"Error in text generation: {str(e)}")
+        return f"Error: {str(e)}"

+# Text classification function
+@spaces.GPU
+def classify_text(text):
+    try:
+        logger.info("Running classification")
+        result = classification_pipeline(text)
+        return str(result)
+    except Exception as e:
+        logger.error(f"Error in classification: {str(e)}")
+        return f"Error: {str(e)}"

+# Simple, integrated UI and logic
+def create_ui():
+    with gr.Blocks() as demo:
+        gr.Markdown("""
+        # Toxic Eye (Simple Version)
+        This system evaluates the toxicity level of input text using two models.
+        """)
+
+        with gr.Row():
+            input_text = gr.Textbox(
+                label="Input Text",
+                placeholder="Enter text to analyze...",
+                lines=3
+            )
+
+        with gr.Row():
+            invoke_button = gr.Button(
+                "Analyze Text",
+                variant="primary",
+                size="lg"
+            )
+
+        with gr.Tabs():
+            with gr.Tab("Text Generation"):
+                gen_output = gr.Textbox(
+                    label="Mistral-7B Output",
+                    lines=5,
+                    interactive=False
+                )
+
+            with gr.Tab("Classification"):
+                class_output = gr.Textbox(
+                    label="Toxic-BERT Output",
+                    lines=5,
+                    interactive=False
+                )
+
+        # Event handler
+        def handle_invoke(text):
+            gen_result = generate_text(text)
+            class_result = classify_text(text)
+            return gen_result, class_result
+
+        invoke_button.click(
+            fn=handle_invoke,
+            inputs=[input_text],
+            outputs=[gen_output, class_output]
+        )
+
+    return demo

+def main():
+    demo = create_ui()
+    demo.launch()

+if __name__ == "__main__":
+    main()
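
Note: the new handle_invoke runs the two @spaces.GPU calls back to back, dropping the old version's asyncio.gather() concurrency. If parallel updates are ever wanted again under the Blocks UI, one possible variant (a minimal sketch, not part of this commit; the stub functions stand in for the real pipelines above) is to register two independent click listeners on the same button, so each model call runs as its own Gradio event:

import gradio as gr

# Stand-ins for the @spaces.GPU functions defined in app.py above.
def generate_text(text):
    return f"generated: {text}"

def classify_text(text):
    return f"classified: {text}"

with gr.Blocks() as demo:
    input_text = gr.Textbox(lines=3, label="Input Text")
    invoke_button = gr.Button("Analyze Text", variant="primary")
    gen_output = gr.Textbox(label="Text Generation", interactive=False)
    class_output = gr.Textbox(label="Classification", interactive=False)

    # Two listeners on one button: a single click fires both, and the two
    # events are queued separately, so the outputs update independently
    # instead of waiting on one sequential handler.
    invoke_button.click(fn=generate_text, inputs=[input_text], outputs=[gen_output])
    invoke_button.click(fn=classify_text, inputs=[input_text], outputs=[class_output])

demo.launch()

Whether the two events actually overlap depends on the Space's queue settings and GPU allocation, so treat this as a sketch rather than a drop-in replacement.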