nyasukun commited on
Commit
b37f1c5
·
verified ·
1 Parent(s): 1e17339

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -64
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
  import torch
4
  from transformers import AutoTokenizer, pipeline
5
  import logging
@@ -12,47 +11,78 @@ logging.basicConfig(
12
  )
13
  logger = logging.getLogger(__name__)
14
 
15
- # モデル定義(シンプルに2つだけ)
16
- TEXT_GENERATION_MODEL = "mistralai/Mistral-7B-v0.1"
17
- CLASSIFICATION_MODEL = "unitary/toxic-bert"
 
 
 
 
 
 
 
 
 
 
18
 
19
- # モデルを事前ロード
20
- def load_text_generation_model():
21
- logger.info(f"Loading text generation model: {TEXT_GENERATION_MODEL}")
22
- tokenizer = AutoTokenizer.from_pretrained(TEXT_GENERATION_MODEL)
23
- text_pipeline = pipeline(
24
- "text-generation",
25
- model=TEXT_GENERATION_MODEL,
26
- tokenizer=tokenizer,
27
- torch_dtype=torch.bfloat16,
28
- trust_remote_code=True,
29
- device_map="auto"
30
- )
31
- return text_pipeline
32
 
33
- def load_classification_model():
34
- logger.info(f"Loading classification model: {CLASSIFICATION_MODEL}")
35
- tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL)
36
- class_pipeline = pipeline(
37
- "text-classification",
38
- model=CLASSIFICATION_MODEL,
39
- tokenizer=tokenizer,
40
- torch_dtype=torch.bfloat16,
41
- trust_remote_code=True,
42
- device_map="auto"
43
- )
44
- return class_pipeline
45
 
46
- # モデルのロード
47
- text_gen_pipeline = load_text_generation_model()
48
- classification_pipeline = load_classification_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # テキスト生成関数
51
  @spaces.GPU
52
- def generate_text(text):
 
53
  try:
54
- logger.info("Running text generation")
55
- outputs = text_gen_pipeline(
56
  text,
57
  max_new_tokens=100,
58
  do_sample=False,
@@ -60,28 +90,48 @@ def generate_text(text):
60
  )
61
  return outputs[0]["generated_text"]
62
  except Exception as e:
63
- logger.error(f"Error in text generation: {str(e)}")
64
  return f"Error: {str(e)}"
65
 
66
- # テキスト分類関数
67
  @spaces.GPU
68
- def classify_text(text):
 
69
  try:
70
- logger.info("Running classification")
71
- result = classification_pipeline(text)
72
  return str(result)
73
  except Exception as e:
74
- logger.error(f"Error in classification: {str(e)}")
75
  return f"Error: {str(e)}"
76
 
77
- # UIとロジックをシンプルに統合
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def create_ui():
 
79
  with gr.Blocks() as demo:
 
80
  gr.Markdown("""
81
- # Toxic Eye (Simple Version)
82
- This system evaluates the toxicity level of input text using two models.
83
  """)
84
 
 
85
  with gr.Row():
86
  input_text = gr.Textbox(
87
  label="Input Text",
@@ -89,6 +139,7 @@ def create_ui():
89
  lines=3
90
  )
91
 
 
92
  with gr.Row():
93
  invoke_button = gr.Button(
94
  "Analyze Text",
@@ -96,36 +147,51 @@ def create_ui():
96
  size="lg"
97
  )
98
 
 
 
 
 
99
  with gr.Tabs():
100
- with gr.Tab("Text Generation"):
101
- gen_output = gr.Textbox(
102
- label="Mistral-7B Output",
103
- lines=5,
104
- interactive=False
105
- )
 
 
 
 
 
 
106
 
107
- with gr.Tab("Classification"):
108
- class_output = gr.Textbox(
109
- label="Toxic-BERT Output",
110
- lines=5,
111
- interactive=False
112
- )
113
-
114
- # イベントハンドラ
115
- def handle_invoke(text):
116
- gen_result = generate_text(text)
117
- class_result = classify_text(text)
118
- return gen_result, class_result
119
 
 
120
  invoke_button.click(
121
  fn=handle_invoke,
122
  inputs=[input_text],
123
- outputs=[gen_output, class_output]
124
  )
125
 
126
  return demo
127
 
128
  def main():
 
 
 
 
129
  demo = create_ui()
130
  demo.launch()
131
 
 
1
  import gradio as gr
 
2
  import torch
3
  from transformers import AutoTokenizer, pipeline
4
  import logging
 
11
  )
12
  logger = logging.getLogger(__name__)
13
 
14
+ # シンプルなモデル定義(3つのローカルモデル)
15
+ TEXT_GENERATION_MODELS = [
16
+ {
17
+ "name": "Llama-2",
18
+ "description": "Known for its robust performance in content analysis",
19
+ "model_path": "meta-llama/Llama-2-7b-hf"
20
+ },
21
+ {
22
+ "name": "Mistral-7B",
23
+ "description": "Offers precise and detailed text evaluation",
24
+ "model_path": "mistralai/Mistral-7B-v0.1"
25
+ }
26
+ ]
27
 
28
+ CLASSIFICATION_MODELS = [
29
+ {
30
+ "name": "Toxic-BERT",
31
+ "description": "Fine-tuned for toxic content detection",
32
+ "model_path": "unitary/toxic-bert"
33
+ }
34
+ ]
 
 
 
 
 
 
35
 
36
+ # グローバル変数でモデルとトークナイザを管理
37
+ tokenizers = {}
38
+ pipelines = {}
 
 
 
 
 
 
 
 
 
39
 
40
+ def preload_models():
41
+ """アプリケーション起動時にモデルを事前ロード"""
42
+ logger.info("Preloading models at application startup...")
43
+
44
+ # テキスト生成モデル
45
+ for model in TEXT_GENERATION_MODELS:
46
+ model_path = model["model_path"]
47
+ try:
48
+ logger.info(f"Preloading text generation model: {model_path}")
49
+ tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
50
+ pipelines[model_path] = pipeline(
51
+ "text-generation",
52
+ model=model_path,
53
+ tokenizer=tokenizers[model_path],
54
+ torch_dtype=torch.bfloat16,
55
+ trust_remote_code=True,
56
+ device_map="auto"
57
+ )
58
+ logger.info(f"Model preloaded successfully: {model_path}")
59
+ except Exception as e:
60
+ logger.error(f"Error preloading model {model_path}: {str(e)}")
61
+
62
+ # 分類モデル
63
+ for model in CLASSIFICATION_MODELS:
64
+ model_path = model["model_path"]
65
+ try:
66
+ logger.info(f"Preloading classification model: {model_path}")
67
+ tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
68
+ pipelines[model_path] = pipeline(
69
+ "text-classification",
70
+ model=model_path,
71
+ tokenizer=tokenizers[model_path],
72
+ torch_dtype=torch.bfloat16,
73
+ trust_remote_code=True,
74
+ device_map="auto"
75
+ )
76
+ logger.info(f"Model preloaded successfully: {model_path}")
77
+ except Exception as e:
78
+ logger.error(f"Error preloading model {model_path}: {str(e)}")
79
 
 
80
  @spaces.GPU
81
+ def generate_text(model_path, text):
82
+ """テキスト生成の実行"""
83
  try:
84
+ logger.info(f"Running text generation with {model_path}")
85
+ outputs = pipelines[model_path](
86
  text,
87
  max_new_tokens=100,
88
  do_sample=False,
 
90
  )
91
  return outputs[0]["generated_text"]
92
  except Exception as e:
93
+ logger.error(f"Error in text generation with {model_path}: {str(e)}")
94
  return f"Error: {str(e)}"
95
 
 
96
  @spaces.GPU
97
+ def classify_text(model_path, text):
98
+ """テキスト分類の実行"""
99
  try:
100
+ logger.info(f"Running classification with {model_path}")
101
+ result = pipelines[model_path](text)
102
  return str(result)
103
  except Exception as e:
104
+ logger.error(f"Error in classification with {model_path}: {str(e)}")
105
  return f"Error: {str(e)}"
106
 
107
+ def handle_invoke(text):
108
+ """すべてのモデルで分析を実行"""
109
+ results = []
110
+
111
+ # テキスト生成モデルの実行
112
+ for model in TEXT_GENERATION_MODELS:
113
+ model_path = model["model_path"]
114
+ result = generate_text(model_path, text)
115
+ results.append(result)
116
+
117
+ # 分類モデルの実行
118
+ for model in CLASSIFICATION_MODELS:
119
+ model_path = model["model_path"]
120
+ result = classify_text(model_path, text)
121
+ results.append(result)
122
+
123
+ return results
124
+
125
  def create_ui():
126
+ """UIの作成"""
127
  with gr.Blocks() as demo:
128
+ # ヘッダー
129
  gr.Markdown("""
130
+ # Toxic Eye (3 Models Version)
131
+ This system evaluates the toxicity level of input text using 3 local models.
132
  """)
133
 
134
+ # 入力セクション
135
  with gr.Row():
136
  input_text = gr.Textbox(
137
  label="Input Text",
 
139
  lines=3
140
  )
141
 
142
+ # 実行ボタン
143
  with gr.Row():
144
  invoke_button = gr.Button(
145
  "Analyze Text",
 
147
  size="lg"
148
  )
149
 
150
+ # モデル出力表示エリア
151
+ gen_outputs = []
152
+ class_outputs = []
153
+
154
  with gr.Tabs():
155
+ # テキスト生成モデルのタブ
156
+ with gr.Tab("Text Generation Models"):
157
+ for model in TEXT_GENERATION_MODELS:
158
+ with gr.Group():
159
+ gr.Markdown(f"### {model['name']}")
160
+ output = gr.Textbox(
161
+ label=f"{model['name']} Output",
162
+ lines=5,
163
+ interactive=False,
164
+ info=model["description"]
165
+ )
166
+ gen_outputs.append(output)
167
 
168
+ # 分類モデルのタブ
169
+ with gr.Tab("Classification Models"):
170
+ for model in CLASSIFICATION_MODELS:
171
+ with gr.Group():
172
+ gr.Markdown(f"### {model['name']}")
173
+ output = gr.Textbox(
174
+ label=f"{model['name']} Output",
175
+ lines=5,
176
+ interactive=False,
177
+ info=model["description"]
178
+ )
179
+ class_outputs.append(output)
180
 
181
+ # イベント接続
182
  invoke_button.click(
183
  fn=handle_invoke,
184
  inputs=[input_text],
185
+ outputs=gen_outputs + class_outputs
186
  )
187
 
188
  return demo
189
 
190
  def main():
191
+ # モデルを事前ロード
192
+ preload_models()
193
+
194
+ # UIを作成して起動
195
  demo = create_ui()
196
  demo.launch()
197