nyasukun committed
Commit 03bbc94 · verified · 1 Parent(s): 08d5a0c

Update app.py

Files changed (1)
  1. app.py +202 -216
app.py CHANGED
@@ -43,240 +43,226 @@ CLASSIFICATION_MODELS = [
     }
 ]
 
-class ModelManager:
-    def __init__(self):
-        self.tokenizers = {}
-        self.pipelines = {}
-        self.api_clients = {}
-        self._initialize_api_clients()
-        self._preload_local_models()
-
-    def _initialize_api_clients(self):
-        """Initialize the Inference API clients"""
-        for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
-            if model["type"] == "api" and "model_id" in model:
-                logger.info(f"Initializing API client for {model['name']}")
-                self.api_clients[model["model_id"]] = InferenceClient(
-                    model["model_id"],
-                    token=True  # use the HF token
-                )
-
-    def _preload_local_models(self):
-        """Preload the local models"""
-        logger.info("Preloading local models at application startup...")
-
-        # Text generation models
-        for model in TEXT_GENERATION_MODELS:
-            if model["type"] == "local" and "model_path" in model:
-                model_path = model["model_path"]
-                try:
-                    logger.info(f"Preloading text generation model: {model_path}")
-                    self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-                    self.pipelines[model_path] = pipeline(
-                        "text-generation",
-                        model=model_path,
-                        tokenizer=self.tokenizers[model_path],
-                        torch_dtype=torch.bfloat16,
-                        trust_remote_code=True,
-                        device_map="auto"
-                    )
-                    logger.info(f"Model preloaded successfully: {model_path}")
-                except Exception as e:
-                    logger.error(f"Error preloading model {model_path}: {str(e)}")
-
-        # Classification models
-        for model in CLASSIFICATION_MODELS:
-            if model["type"] == "local" and "model_path" in model:
-                model_path = model["model_path"]
-                try:
-                    logger.info(f"Preloading classification model: {model_path}")
-                    self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-                    self.pipelines[model_path] = pipeline(
-                        "text-classification",
-                        model=model_path,
-                        tokenizer=self.tokenizers[model_path],
-                        torch_dtype=torch.bfloat16,
-                        trust_remote_code=True,
-                        device_map="auto"
-                    )
-                    logger.info(f"Model preloaded successfully: {model_path}")
-                except Exception as e:
-                    logger.error(f"Error preloading model {model_path}: {str(e)}")
-
-    @spaces.GPU
-    def generate_text_local(self, model_path, text):
-        """Text generation with a local model"""
-        try:
-            logger.info(f"Running local text generation with {model_path}")
-            outputs = self.pipelines[model_path](
-                text,
-                max_new_tokens=100,
-                do_sample=False,
-                num_return_sequences=1
-            )
-            return outputs[0]["generated_text"]
-        except Exception as e:
-            logger.error(f"Error in local text generation with {model_path}: {str(e)}")
-            return f"Error: {str(e)}"
-
-    def generate_text_api(self, model_id, text):
-        """Text generation via the Inference API"""
-        try:
-            logger.info(f"Running API text generation with {model_id}")
-            response = self.api_clients[model_id].text_generation(
-                text,
-                max_new_tokens=100,
-                temperature=0.7
-            )
-            return response
-        except Exception as e:
-            logger.error(f"Error in API text generation with {model_id}: {str(e)}")
-            return f"Error: {str(e)}"
-
-    @spaces.GPU
-    def classify_text_local(self, model_path, text):
-        """Text classification with a local model"""
-        try:
-            logger.info(f"Running local classification with {model_path}")
-            result = self.pipelines[model_path](text)
-            return str(result)
-        except Exception as e:
-            logger.error(f"Error in local classification with {model_path}: {str(e)}")
-            return f"Error: {str(e)}"
-
-    def classify_text_api(self, model_id, text):
-        """Text classification via the Inference API"""
-        try:
-            logger.info(f"Running API classification with {model_id}")
-            response = self.api_clients[model_id].text_classification(text)
-            return str(response)
-        except Exception as e:
-            logger.error(f"Error in API classification with {model_id}: {str(e)}")
-            return f"Error: {str(e)}"
-
-    def run_models(self, text, selected_types):
-        """Run analysis with the models of the selected types"""
-        results = []
-
-        # Run the text generation models
-        for model in TEXT_GENERATION_MODELS:
-            if model["type"] in selected_types:
-                if model["type"] == "local":
-                    result = self.generate_text_local(model["model_path"], text)
-                else:  # api
-                    result = self.generate_text_api(model["model_id"], text)
-                results.append(f"{model['name']}: {result}")
-
-        # Run the classification models
-        for model in CLASSIFICATION_MODELS:
-            if model["type"] in selected_types:
-                if model["type"] == "local":
-                    result = self.classify_text_local(model["model_path"], text)
-                else:  # api
-                    result = self.classify_text_api(model["model_id"], text)
-                results.append(f"{model['name']}: {result}")
-
-        # Pad the results list to the expected length
-        while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
-            results.append("")
-
-        return results
-
-class UIManager:
-    def __init__(self, model_manager):
-        self.model_manager = model_manager
-
-    def create_ui(self):
-        """Build the UI"""
-        with gr.Blocks() as demo:
-            # Header
-            gr.Markdown("""
-            # Toxic Eye (Class-based Version)
-            This system evaluates the toxicity level of input text using both local models and Inference API.
-            """)
-
-            # Input section
-            with gr.Row():
-                input_text = gr.Textbox(
-                    label="Input Text",
-                    placeholder="Enter text to analyze...",
-                    lines=3
-                )
-
-            # Filter section
-            with gr.Row():
-                filter_checkboxes = gr.CheckboxGroup(
-                    choices=["local", "api"],
-                    value=["local", "api"],
-                    label="Filter Models",
-                    info="Choose which types of models to use",
-                    interactive=True
-                )
-
-            # Run button
-            with gr.Row():
-                invoke_button = gr.Button(
-                    "Analyze Text",
-                    variant="primary",
-                    size="lg"
-                )
-
-            # Model output display area
-            all_outputs = []
-
-            with gr.Tabs():
-                # Text generation models tab
-                with gr.Tab("Text Generation Models"):
-                    for model in TEXT_GENERATION_MODELS:
-                        with gr.Group():
-                            gr.Markdown(f"### {model['name']} ({model['type']})")
-                            output = gr.Textbox(
-                                label=f"{model['name']} Output",
-                                lines=5,
-                                interactive=False,
-                                info=model["description"]
-                            )
-                            all_outputs.append(output)
-
-                # Classification models tab
-                with gr.Tab("Classification Models"):
-                    for model in CLASSIFICATION_MODELS:
-                        with gr.Group():
-                            gr.Markdown(f"### {model['name']} ({model['type']})")
-                            output = gr.Textbox(
-                                label=f"{model['name']} Output",
-                                lines=5,
-                                interactive=False,
-                                info=model["description"]
-                            )
-                            all_outputs.append(output)
-
-            # Wire up the event
-            invoke_button.click(
-                fn=self.handle_invoke,
-                inputs=[input_text, filter_checkboxes],
-                outputs=all_outputs
-            )
-
-        return demo
-
-    def handle_invoke(self, text, selected_types):
-        """Handle model execution"""
-        return self.model_manager.run_models(text, selected_types)
-
-class ToxicityApp:
-    def __init__(self):
-        self.model_manager = ModelManager()
-        self.ui_manager = UIManager(self.model_manager)
-
-    def run(self):
-        """Launch the app"""
-        demo = self.ui_manager.create_ui()
-        demo.launch()
+# Manage the models and API clients in module-level globals
+tokenizers = {}
+pipelines = {}
+api_clients = {}
+
+def initialize_api_clients():
+    """Initialize the Inference API clients"""
+    for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
+        if model["type"] == "api" and "model_id" in model:
+            logger.info(f"Initializing API client for {model['name']}")
+            api_clients[model["model_id"]] = InferenceClient(
+                model["model_id"],
+                token=True  # use the HF token
+            )
+
+def preload_local_models():
+    """Preload the local models"""
+    logger.info("Preloading local models at application startup...")
+
+    # Text generation models
+    for model in TEXT_GENERATION_MODELS:
+        if model["type"] == "local" and "model_path" in model:
+            model_path = model["model_path"]
+            try:
+                logger.info(f"Preloading text generation model: {model_path}")
+                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
+                pipelines[model_path] = pipeline(
+                    "text-generation",
+                    model=model_path,
+                    tokenizer=tokenizers[model_path],
+                    torch_dtype=torch.bfloat16,
+                    trust_remote_code=True,
+                    device_map="auto"
+                )
+                logger.info(f"Model preloaded successfully: {model_path}")
+            except Exception as e:
+                logger.error(f"Error preloading model {model_path}: {str(e)}")
+
+    # Classification models
+    for model in CLASSIFICATION_MODELS:
+        if model["type"] == "local" and "model_path" in model:
+            model_path = model["model_path"]
+            try:
+                logger.info(f"Preloading classification model: {model_path}")
+                tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
+                pipelines[model_path] = pipeline(
+                    "text-classification",
+                    model=model_path,
+                    tokenizer=tokenizers[model_path],
+                    torch_dtype=torch.bfloat16,
+                    trust_remote_code=True,
+                    device_map="auto"
+                )
+                logger.info(f"Model preloaded successfully: {model_path}")
+            except Exception as e:
+                logger.error(f"Error preloading model {model_path}: {str(e)}")
+
+@spaces.GPU
+def generate_text_local(model_path, text):
+    """Text generation with a local model"""
+    try:
+        logger.info(f"Running local text generation with {model_path}")
+        outputs = pipelines[model_path](
+            text,
+            max_new_tokens=100,
+            do_sample=False,
+            num_return_sequences=1
+        )
+        return outputs[0]["generated_text"]
+    except Exception as e:
+        logger.error(f"Error in local text generation with {model_path}: {str(e)}")
+        return f"Error: {str(e)}"
+
+def generate_text_api(model_id, text):
+    """Text generation via the Inference API"""
+    try:
+        logger.info(f"Running API text generation with {model_id}")
+        response = api_clients[model_id].text_generation(
+            text,
+            max_new_tokens=100,
+            temperature=0.7
+        )
+        return response
+    except Exception as e:
+        logger.error(f"Error in API text generation with {model_id}: {str(e)}")
+        return f"Error: {str(e)}"
+
+@spaces.GPU
+def classify_text_local(model_path, text):
+    """Text classification with a local model"""
+    try:
+        logger.info(f"Running local classification with {model_path}")
+        result = pipelines[model_path](text)
+        return str(result)
+    except Exception as e:
+        logger.error(f"Error in local classification with {model_path}: {str(e)}")
+        return f"Error: {str(e)}"
+
+def classify_text_api(model_id, text):
+    """Text classification via the Inference API"""
+    try:
+        logger.info(f"Running API classification with {model_id}")
+        response = api_clients[model_id].text_classification(text)
+        return str(response)
+    except Exception as e:
+        logger.error(f"Error in API classification with {model_id}: {str(e)}")
+        return f"Error: {str(e)}"
+
+def handle_invoke(text, selected_types):
+    """Run analysis with the models of the selected types"""
+    results = []
+
+    # Run the text generation models
+    for model in TEXT_GENERATION_MODELS:
+        if model["type"] in selected_types:
+            if model["type"] == "local":
+                result = generate_text_local(model["model_path"], text)
+            else:  # api
+                result = generate_text_api(model["model_id"], text)
+            results.append(f"{model['name']}: {result}")
+
+    # Run the classification models
+    for model in CLASSIFICATION_MODELS:
+        if model["type"] in selected_types:
+            if model["type"] == "local":
+                result = classify_text_local(model["model_path"], text)
+            else:  # api
+                result = classify_text_api(model["model_id"], text)
+            results.append(f"{model['name']}: {result}")
+
+    # Pad the results list to the expected length
+    while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
+        results.append("")
+
+    return results
+
+def create_ui():
+    """Build the UI"""
+    with gr.Blocks() as demo:
+        # Header
+        gr.Markdown("""
+        # Toxic Eye (Local + API Version)
+        This system evaluates the toxicity level of input text using both local models and Inference API.
+        """)
+
+        # Input section
+        with gr.Row():
+            input_text = gr.Textbox(
+                label="Input Text",
+                placeholder="Enter text to analyze...",
+                lines=3
+            )
+
+        # Filter section
+        with gr.Row():
+            filter_checkboxes = gr.CheckboxGroup(
+                choices=["local", "api"],
+                value=["local", "api"],
+                label="Filter Models",
+                info="Choose which types of models to use",
+                interactive=True
+            )
+
+        # Run button
+        with gr.Row():
+            invoke_button = gr.Button(
+                "Analyze Text",
+                variant="primary",
+                size="lg"
+            )
+
+        # Model output display area
+        all_outputs = []
+
+        with gr.Tabs():
+            # Text generation models tab
+            with gr.Tab("Text Generation Models"):
+                for model in TEXT_GENERATION_MODELS:
+                    with gr.Group():
+                        gr.Markdown(f"### {model['name']} ({model['type']})")
+                        output = gr.Textbox(
+                            label=f"{model['name']} Output",
+                            lines=5,
+                            interactive=False,
+                            info=model["description"]
+                        )
+                        all_outputs.append(output)
+
+            # Classification models tab
+            with gr.Tab("Classification Models"):
+                for model in CLASSIFICATION_MODELS:
+                    with gr.Group():
+                        gr.Markdown(f"### {model['name']} ({model['type']})")
+                        output = gr.Textbox(
+                            label=f"{model['name']} Output",
+                            lines=5,
+                            interactive=False,
+                            info=model["description"]
+                        )
+                        all_outputs.append(output)
+
+        # Wire up the event
+        invoke_button.click(
+            fn=handle_invoke,
+            inputs=[input_text, filter_checkboxes],
+            outputs=all_outputs
+        )
+
+    return demo
 
 def main():
-    app = ToxicityApp()
-    app.run()
+    # Initialize the API clients
+    initialize_api_clients()
+
+    # Preload the local models
+    preload_local_models()
+
+    # Build and launch the UI
+    demo = create_ui()
+    demo.launch()
 
 if __name__ == "__main__":
    main()
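
For reference, a minimal sketch of the module-level layout this commit adopts: global caches populated once at startup, plain functions dispatching per model type, and a handler that pads its result list so Gradio receives one value per output component. The registry entry and the echo "pipeline" below are hypothetical stand-ins, not the Space's actual models.

# Minimal sketch of the refactored layout; the model entry and the
# echo pipeline are hypothetical stand-ins, not the Space's real config.
TEXT_GENERATION_MODELS = [
    {"name": "demo-gen", "type": "local", "model_path": "demo/gen"},  # hypothetical
]
CLASSIFICATION_MODELS = []

# Module-level state, mirroring the refactored app.py
pipelines = {}

def preload_local_models():
    """Populate the global pipeline cache (stand-in for transformers.pipeline)."""
    for model in TEXT_GENERATION_MODELS:
        if model["type"] == "local" and "model_path" in model:
            pipelines[model["model_path"]] = lambda text: [{"generated_text": text}]

def generate_text_local(model_path, text):
    """Look up the cached pipeline and return its generated text."""
    return pipelines[model_path](text)[0]["generated_text"]

def handle_invoke(text, selected_types):
    """Run every model whose type is selected; pad to one value per output."""
    results = []
    for model in TEXT_GENERATION_MODELS:
        if model["type"] in selected_types:
            results.append(f"{model['name']}: {generate_text_local(model['model_path'], text)}")
    while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
        results.append("")
    return results

if __name__ == "__main__":
    preload_local_models()
    print(handle_invoke("hello", ["local"]))  # -> ['demo-gen: hello']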
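
On the API path, token=True asks huggingface_hub's InferenceClient to use the locally saved Hugging Face token (e.g. from HF_TOKEN or huggingface-cli login). A hedged sketch of the two calls that generate_text_api() and classify_text_api() wrap; the model IDs are placeholders, not the ones configured in app.py.

from huggingface_hub import InferenceClient

# Placeholder model IDs; the real ones live in TEXT_GENERATION_MODELS
# and CLASSIFICATION_MODELS in app.py.
gen_client = InferenceClient("some-org/generation-model", token=True)
clf_client = InferenceClient("some-org/classification-model", token=True)

# Mirrors generate_text_api(): returns the generated string
completion = gen_client.text_generation(
    "text to analyze",
    max_new_tokens=100,
    temperature=0.7,
)

# Mirrors classify_text_api(): returns label/score elements,
# which app.py stringifies for display
labels = clf_client.text_classification("text to analyze")
print(completion, labels)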