nyasukun commited on
Commit
08d5a0c
·
verified ·
1 Parent(s): e09e1bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -202
app.py CHANGED
@@ -43,226 +43,240 @@ CLASSIFICATION_MODELS = [
43
  }
44
  ]
45
 
46
- # グローバル変数でモデルとAPIクライアントを管理
47
- tokenizers = {}
48
- pipelines = {}
49
- api_clients = {}
 
 
 
50
 
51
- def initialize_api_clients():
52
- """Inference APIクライアントの初期化"""
53
- for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
54
- if model["type"] == "api" and "model_id" in model:
55
- logger.info(f"Initializing API client for {model['name']}")
56
- api_clients[model["model_id"]] = InferenceClient(
57
- model["model_id"],
58
- token=True # HFトークンを使用
59
- )
60
-
61
- def preload_local_models():
62
- """ローカルモデルを事前ロード"""
63
- logger.info("Preloading local models at application startup...")
64
-
65
- # テキスト生成モデル
66
- for model in TEXT_GENERATION_MODELS:
67
- if model["type"] == "local" and "model_path" in model:
68
- model_path = model["model_path"]
69
- try:
70
- logger.info(f"Preloading text generation model: {model_path}")
71
- tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
72
- pipelines[model_path] = pipeline(
73
- "text-generation",
74
- model=model_path,
75
- tokenizer=tokenizers[model_path],
76
- torch_dtype=torch.bfloat16,
77
- trust_remote_code=True,
78
- device_map="auto"
79
- )
80
- logger.info(f"Model preloaded successfully: {model_path}")
81
- except Exception as e:
82
- logger.error(f"Error preloading model {model_path}: {str(e)}")
83
-
84
- # 分類モデル
85
- for model in CLASSIFICATION_MODELS:
86
- if model["type"] == "local" and "model_path" in model:
87
- model_path = model["model_path"]
88
- try:
89
- logger.info(f"Preloading classification model: {model_path}")
90
- tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
91
- pipelines[model_path] = pipeline(
92
- "text-classification",
93
- model=model_path,
94
- tokenizer=tokenizers[model_path],
95
- torch_dtype=torch.bfloat16,
96
- trust_remote_code=True,
97
- device_map="auto"
98
  )
99
- logger.info(f"Model preloaded successfully: {model_path}")
100
- except Exception as e:
101
- logger.error(f"Error preloading model {model_path}: {str(e)}")
102
 
103
- @spaces.GPU
104
- def generate_text_local(model_path, text):
105
- """ローカルモデルでのテキスト生成"""
106
- try:
107
- logger.info(f"Running local text generation with {model_path}")
108
- outputs = pipelines[model_path](
109
- text,
110
- max_new_tokens=100,
111
- do_sample=False,
112
- num_return_sequences=1
113
- )
114
- return outputs[0]["generated_text"]
115
- except Exception as e:
116
- logger.error(f"Error in local text generation with {model_path}: {str(e)}")
117
- return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- def generate_text_api(model_id, text):
120
- """API経由でのテキスト生成"""
121
- try:
122
- logger.info(f"Running API text generation with {model_id}")
123
- response = api_clients[model_id].text_generation(
124
- text,
125
- max_new_tokens=100,
126
- temperature=0.7
127
- )
128
- return response
129
- except Exception as e:
130
- logger.error(f"Error in API text generation with {model_id}: {str(e)}")
131
- return f"Error: {str(e)}"
 
 
132
 
133
- @spaces.GPU
134
- def classify_text_local(model_path, text):
135
- """ローカルモデルでのテキスト分類"""
136
- try:
137
- logger.info(f"Running local classification with {model_path}")
138
- result = pipelines[model_path](text)
139
- return str(result)
140
- except Exception as e:
141
- logger.error(f"Error in local classification with {model_path}: {str(e)}")
142
- return f"Error: {str(e)}"
 
 
 
143
 
144
- def classify_text_api(model_id, text):
145
- """API経由でのテキスト分類"""
146
- try:
147
- logger.info(f"Running API classification with {model_id}")
148
- response = api_clients[model_id].text_classification(text)
149
- return str(response)
150
- except Exception as e:
151
- logger.error(f"Error in API classification with {model_id}: {str(e)}")
152
- return f"Error: {str(e)}"
 
153
 
154
- def handle_invoke(text, selected_types):
155
- """選択されたタイプのモデルで分析を実行"""
156
- results = []
157
-
158
- # テキスト生成モデルの実行
159
- for model in TEXT_GENERATION_MODELS:
160
- if model["type"] in selected_types:
161
- if model["type"] == "local":
162
- result = generate_text_local(model["model_path"], text)
163
- else: # api
164
- result = generate_text_api(model["model_id"], text)
165
- results.append(f"{model['name']}: {result}")
166
-
167
- # 分類モデルの実行
168
- for model in CLASSIFICATION_MODELS:
169
- if model["type"] in selected_types:
170
- if model["type"] == "local":
171
- result = classify_text_local(model["model_path"], text)
172
- else: # api
173
- result = classify_text_api(model["model_id"], text)
174
- results.append(f"{model['name']}: {result}")
175
-
176
- # 結果リストの長さを調整
177
- while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
178
- results.append("")
179
-
180
- return results
181
 
182
- def create_ui():
183
- """UIの作成"""
184
- with gr.Blocks() as demo:
185
- # ヘッダー
186
- gr.Markdown("""
187
- # Toxic Eye (Local + API Version)
188
- This system evaluates the toxicity level of input text using both local models and Inference API.
189
- """)
190
 
191
- # 入力セクション
192
- with gr.Row():
193
- input_text = gr.Textbox(
194
- label="Input Text",
195
- placeholder="Enter text to analyze...",
196
- lines=3
197
- )
 
198
 
199
- # フィルターセクション
200
- with gr.Row():
201
- filter_checkboxes = gr.CheckboxGroup(
202
- choices=["local", "api"],
203
- value=["local", "api"],
204
- label="Filter Models",
205
- info="Choose which types of models to use",
206
- interactive=True
207
- )
208
 
209
- # 実行ボタン
210
- with gr.Row():
211
- invoke_button = gr.Button(
212
- "Analyze Text",
213
- variant="primary",
214
- size="lg"
215
- )
216
 
217
- # モデル出力表示エリア
218
- all_outputs = []
 
 
 
219
 
220
- with gr.Tabs():
221
- # テキスト生成モデルのタブ
222
- with gr.Tab("Text Generation Models"):
223
- for model in TEXT_GENERATION_MODELS:
224
- with gr.Group():
225
- gr.Markdown(f"### {model['name']} ({model['type']})")
226
- output = gr.Textbox(
227
- label=f"{model['name']} Output",
228
- lines=5,
229
- interactive=False,
230
- info=model["description"]
231
- )
232
- all_outputs.append(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- # 分類モデルのタブ
235
- with gr.Tab("Classification Models"):
236
- for model in CLASSIFICATION_MODELS:
237
- with gr.Group():
238
- gr.Markdown(f"### {model['name']} ({model['type']})")
239
- output = gr.Textbox(
240
- label=f"{model['name']} Output",
241
- lines=5,
242
- interactive=False,
243
- info=model["description"]
244
- )
245
- all_outputs.append(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- # イベント接続
248
- invoke_button.click(
249
- fn=handle_invoke,
250
- inputs=[input_text, filter_checkboxes],
251
- outputs=all_outputs
252
- )
253
 
254
- return demo
 
 
255
 
256
- def main():
257
- # APIクライアントの初期化
258
- initialize_api_clients()
 
259
 
260
- # ローカルモデルを事前ロード
261
- preload_local_models()
262
-
263
- # UIを作成して起動
264
- demo = create_ui()
265
- demo.launch()
 
 
266
 
267
  if __name__ == "__main__":
268
  main()
 
43
  }
44
  ]
45
 
46
+ class ModelManager:
47
+ def __init__(self):
48
+ self.tokenizers = {}
49
+ self.pipelines = {}
50
+ self.api_clients = {}
51
+ self._initialize_api_clients()
52
+ self._preload_local_models()
53
 
54
+ def _initialize_api_clients(self):
55
+ """Inference APIクライアントの初期化"""
56
+ for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
57
+ if model["type"] == "api" and "model_id" in model:
58
+ logger.info(f"Initializing API client for {model['name']}")
59
+ self.api_clients[model["model_id"]] = InferenceClient(
60
+ model["model_id"],
61
+ token=True # HFトークンを使用
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
 
 
 
63
 
64
+ def _preload_local_models(self):
65
+ """ローカルモデルを事前ロード"""
66
+ logger.info("Preloading local models at application startup...")
67
+
68
+ # テキスト生成モデル
69
+ for model in TEXT_GENERATION_MODELS:
70
+ if model["type"] == "local" and "model_path" in model:
71
+ model_path = model["model_path"]
72
+ try:
73
+ logger.info(f"Preloading text generation model: {model_path}")
74
+ self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
75
+ self.pipelines[model_path] = pipeline(
76
+ "text-generation",
77
+ model=model_path,
78
+ tokenizer=self.tokenizers[model_path],
79
+ torch_dtype=torch.bfloat16,
80
+ trust_remote_code=True,
81
+ device_map="auto"
82
+ )
83
+ logger.info(f"Model preloaded successfully: {model_path}")
84
+ except Exception as e:
85
+ logger.error(f"Error preloading model {model_path}: {str(e)}")
86
+
87
+ # 分類モデル
88
+ for model in CLASSIFICATION_MODELS:
89
+ if model["type"] == "local" and "model_path" in model:
90
+ model_path = model["model_path"]
91
+ try:
92
+ logger.info(f"Preloading classification model: {model_path}")
93
+ self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
94
+ self.pipelines[model_path] = pipeline(
95
+ "text-classification",
96
+ model=model_path,
97
+ tokenizer=self.tokenizers[model_path],
98
+ torch_dtype=torch.bfloat16,
99
+ trust_remote_code=True,
100
+ device_map="auto"
101
+ )
102
+ logger.info(f"Model preloaded successfully: {model_path}")
103
+ except Exception as e:
104
+ logger.error(f"Error preloading model {model_path}: {str(e)}")
105
 
106
+ @spaces.GPU
107
+ def generate_text_local(self, model_path, text):
108
+ """ローカルモデルでのテキスト生成"""
109
+ try:
110
+ logger.info(f"Running local text generation with {model_path}")
111
+ outputs = self.pipelines[model_path](
112
+ text,
113
+ max_new_tokens=100,
114
+ do_sample=False,
115
+ num_return_sequences=1
116
+ )
117
+ return outputs[0]["generated_text"]
118
+ except Exception as e:
119
+ logger.error(f"Error in local text generation with {model_path}: {str(e)}")
120
+ return f"Error: {str(e)}"
121
 
122
+ def generate_text_api(self, model_id, text):
123
+ """API経由でのテキスト生成"""
124
+ try:
125
+ logger.info(f"Running API text generation with {model_id}")
126
+ response = self.api_clients[model_id].text_generation(
127
+ text,
128
+ max_new_tokens=100,
129
+ temperature=0.7
130
+ )
131
+ return response
132
+ except Exception as e:
133
+ logger.error(f"Error in API text generation with {model_id}: {str(e)}")
134
+ return f"Error: {str(e)}"
135
 
136
+ @spaces.GPU
137
+ def classify_text_local(self, model_path, text):
138
+ """ローカルモデルでのテキスト分類"""
139
+ try:
140
+ logger.info(f"Running local classification with {model_path}")
141
+ result = self.pipelines[model_path](text)
142
+ return str(result)
143
+ except Exception as e:
144
+ logger.error(f"Error in local classification with {model_path}: {str(e)}")
145
+ return f"Error: {str(e)}"
146
 
147
+ def classify_text_api(self, model_id, text):
148
+ """API経由でのテキスト分類"""
149
+ try:
150
+ logger.info(f"Running API classification with {model_id}")
151
+ response = self.api_clients[model_id].text_classification(text)
152
+ return str(response)
153
+ except Exception as e:
154
+ logger.error(f"Error in API classification with {model_id}: {str(e)}")
155
+ return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ def run_models(self, text, selected_types):
158
+ """選択されたタイプのモデルで分析を実行"""
159
+ results = []
 
 
 
 
 
160
 
161
+ # テキスト生成モデルの実行
162
+ for model in TEXT_GENERATION_MODELS:
163
+ if model["type"] in selected_types:
164
+ if model["type"] == "local":
165
+ result = self.generate_text_local(model["model_path"], text)
166
+ else: # api
167
+ result = self.generate_text_api(model["model_id"], text)
168
+ results.append(f"{model['name']}: {result}")
169
 
170
+ # 分類モデルの実行
171
+ for model in CLASSIFICATION_MODELS:
172
+ if model["type"] in selected_types:
173
+ if model["type"] == "local":
174
+ result = self.classify_text_local(model["model_path"], text)
175
+ else: # api
176
+ result = self.classify_text_api(model["model_id"], text)
177
+ results.append(f"{model['name']}: {result}")
 
178
 
179
+ # 結果リストの長さを調整
180
+ while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
181
+ results.append("")
 
 
 
 
182
 
183
+ return results
184
+
185
+ class UIManager:
186
+ def __init__(self, model_manager):
187
+ self.model_manager = model_manager
188
 
189
+ def create_ui(self):
190
+ """UIの作成"""
191
+ with gr.Blocks() as demo:
192
+ # ヘッダー
193
+ gr.Markdown("""
194
+ # Toxic Eye (Class-based Version)
195
+ This system evaluates the toxicity level of input text using both local models and Inference API.
196
+ """)
197
+
198
+ # 入力セクション
199
+ with gr.Row():
200
+ input_text = gr.Textbox(
201
+ label="Input Text",
202
+ placeholder="Enter text to analyze...",
203
+ lines=3
204
+ )
205
+
206
+ # フィルターセクション
207
+ with gr.Row():
208
+ filter_checkboxes = gr.CheckboxGroup(
209
+ choices=["local", "api"],
210
+ value=["local", "api"],
211
+ label="Filter Models",
212
+ info="Choose which types of models to use",
213
+ interactive=True
214
+ )
215
+
216
+ # 実行ボタン
217
+ with gr.Row():
218
+ invoke_button = gr.Button(
219
+ "Analyze Text",
220
+ variant="primary",
221
+ size="lg"
222
+ )
223
 
224
+ # モデル出力表示エリア
225
+ all_outputs = []
226
+
227
+ with gr.Tabs():
228
+ # テキスト生成モデルのタブ
229
+ with gr.Tab("Text Generation Models"):
230
+ for model in TEXT_GENERATION_MODELS:
231
+ with gr.Group():
232
+ gr.Markdown(f"### {model['name']} ({model['type']})")
233
+ output = gr.Textbox(
234
+ label=f"{model['name']} Output",
235
+ lines=5,
236
+ interactive=False,
237
+ info=model["description"]
238
+ )
239
+ all_outputs.append(output)
240
+
241
+ # 分類モデルのタブ
242
+ with gr.Tab("Classification Models"):
243
+ for model in CLASSIFICATION_MODELS:
244
+ with gr.Group():
245
+ gr.Markdown(f"### {model['name']} ({model['type']})")
246
+ output = gr.Textbox(
247
+ label=f"{model['name']} Output",
248
+ lines=5,
249
+ interactive=False,
250
+ info=model["description"]
251
+ )
252
+ all_outputs.append(output)
253
+
254
+ # イベント接続
255
+ invoke_button.click(
256
+ fn=self.handle_invoke,
257
+ inputs=[input_text, filter_checkboxes],
258
+ outputs=all_outputs
259
+ )
260
 
261
+ return demo
 
 
 
 
 
262
 
263
+ def handle_invoke(self, text, selected_types):
264
+ """モデル実行をハンドリング"""
265
+ return self.model_manager.run_models(text, selected_types)
266
 
267
+ class ToxicityApp:
268
+ def __init__(self):
269
+ self.model_manager = ModelManager()
270
+ self.ui_manager = UIManager(self.model_manager)
271
 
272
+ def run(self):
273
+ """アプリを起動"""
274
+ demo = self.ui_manager.create_ui()
275
+ demo.launch()
276
+
277
+ def main():
278
+ app = ToxicityApp()
279
+ app.run()
280
 
281
  if __name__ == "__main__":
282
  main()