Spaces:
Runtime error
Runtime error
import gradio as gr | |
from sdc_classifier import SDCClassifier | |
from dotenv import load_dotenv | |
import torch | |
import json | |
import os | |
# Load environment variables | |
load_dotenv() | |
def initialize_environment(): | |
"""Ініціалізація середовища при першому запуску""" | |
DEFAULT_CLASSES_FILE = "classes.json" | |
DEFAULT_SIGNATURES_FILE = "signatures.npz" | |
CACHE_FILE = "embeddings_cache.db" | |
initial_info = { | |
"status": "initializing", | |
"model_info": {}, | |
"classes_info": {}, | |
"errors": [] | |
} | |
# Перевіряємо наявність необхідних файлів | |
if not os.path.exists(DEFAULT_CLASSES_FILE): | |
initial_info["errors"].append(f"ПОМИЛКА: Файл {DEFAULT_CLASSES_FILE} не знайдено!") | |
initial_info["status"] = "error" | |
return initial_info | |
# Створюємо класифікатор та завантажуємо класи | |
try: | |
classifier = SDCClassifier() | |
classes = classifier.load_classes(DEFAULT_CLASSES_FILE) | |
# Збираємо інформацію про класи | |
initial_info["classes_info"] = { | |
"total_classes": len(classes), | |
"classes_list": list(classes.keys()), | |
"hints_per_class": {cls: len(hints) for cls, hints in classes.items()} | |
} | |
# Якщо signatures не існують, створюємо нові | |
if not os.path.exists(DEFAULT_SIGNATURES_FILE): | |
initial_info["status"] = "creating_signatures" | |
result = classifier.initialize_signatures( | |
force_rebuild=True, | |
signatures_file=DEFAULT_SIGNATURES_FILE | |
) | |
if isinstance(result, str) and "error" in result.lower(): | |
initial_info["errors"].append(result) | |
initial_info["status"] = "error" | |
return initial_info | |
# Завантажуємо інформацію про модель | |
classifier.save_model_info("model_info.json") | |
with open("model_info.json", "r") as f: | |
initial_info["model_info"] = json.load(f) | |
initial_info["status"] = "success" | |
return initial_info, classifier | |
except Exception as e: | |
initial_info["errors"].append(f"ПОМИЛКА при ініціалізації: {str(e)}") | |
initial_info["status"] = "error" | |
return initial_info, None | |
def create_classifier(model_type, openai_model=None, local_model=None, device=None): | |
""" | |
Створення класифікатора з відповідними параметрами | |
Args: | |
model_type: тип моделі ("OpenAI" або "Local") | |
openai_model: назва моделі OpenAI | |
local_model: шлях до локальної моделі | |
device: пристрій для локальної моделі | |
Returns: | |
SDCClassifier: налаштований класифікатор | |
""" | |
if model_type == "OpenAI": | |
return SDCClassifier() | |
else: | |
return SDCClassifier(local_model=local_model, device=device) | |
def main(): | |
# Константи файлів | |
DEFAULT_CLASSES_FILE = "classes.json" | |
DEFAULT_SIGNATURES_FILE = "signatures.npz" | |
CACHE_FILE = "embeddings_cache.db" | |
# Перевіряємо та ініціалізуємо середовище | |
init_result = initialize_environment() | |
if isinstance(init_result, tuple): | |
initial_info, classifier = init_result | |
else: | |
initial_info = init_result | |
print("Не вдалося ініціалізувати середовище") | |
return | |
with gr.Blocks() as demo: | |
gr.Markdown("# SDC Classifier") | |
# Додаємо інформаційний блок про модель та класи | |
with gr.Accordion("Інформація про систему", open=True): | |
system_info = gr.JSON( | |
value=initial_info, | |
label="Статус системи" | |
) | |
if initial_info["status"] == "success": | |
gr.Markdown(f""" | |
### Поточна конфігурація: | |
- Модель: {initial_info['model_info'].get('using_local', 'OpenAI')} | |
- Кількість класів: {initial_info['classes_info']['total_classes']} | |
- Класи: {', '.join(initial_info['classes_info']['classes_list'])} | |
""") | |
else: | |
gr.Markdown(f""" | |
### Помилки ініціалізації: | |
{chr(10).join('- ' + err for err in initial_info['errors'])} | |
""") | |
with gr.Tabs(): | |
# Вкладка 1: Single Text Testing | |
with gr.TabItem("Тестування одного тексту"): | |
with gr.Row(): | |
with gr.Column(): | |
text_input = gr.Textbox( | |
label="Введіть текст для аналізу", | |
lines=5, | |
placeholder="Введіть текст..." | |
) | |
threshold_slider = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
value=0.3, | |
step=0.05, | |
label="Поріг впевненості" | |
) | |
single_process_btn = gr.Button("Проаналізувати") | |
with gr.Column(): | |
result_text = gr.JSON(label="Результати аналізу") | |
# Налаштування моделі | |
with gr.Accordion("Налаштування моделі", open=False): | |
with gr.Row(): | |
model_type = gr.Radio( | |
choices=["OpenAI", "Local"], | |
value="OpenAI", | |
label="Тип моделі" | |
) | |
model_choice = gr.Dropdown( | |
choices=[ | |
"text-embedding-3-large", | |
"text-embedding-3-small" | |
], | |
value="text-embedding-3-large", | |
label="OpenAI model", | |
visible=True | |
) | |
local_model_path = gr.Textbox( | |
value="cambridgeltl/SapBERT-from-PubMedBERT-fulltext", | |
label="Шлях до локальної моделі", | |
visible=False | |
) | |
device_choice = gr.Radio( | |
choices=["cuda", "cpu"], | |
value="cuda" if torch.cuda.is_available() else "cpu", | |
label="Пристрій для локальної моделі", | |
visible=False | |
) | |
with gr.Row(): | |
json_file = gr.File( | |
label="Завантажити новий JSON з класами", | |
file_types=[".json"] | |
) | |
force_rebuild = gr.Checkbox( | |
label="Примусово перебудувати signatures", | |
value=False | |
) | |
with gr.Row(): | |
build_btn = gr.Button("Оновити signatures") | |
build_out = gr.Label(label="Статус signatures") | |
cache_stats = gr.JSON(label="Статистика кешу", value={}) | |
# Вкладка 2: Batch Processing | |
with gr.TabItem("Пакетна обробка"): | |
gr.Markdown("## 1) Завантаження даних") | |
with gr.Row(): | |
csv_input = gr.Textbox( | |
value="messages.csv", | |
label="CSV-файл" | |
) | |
emb_input = gr.Textbox( | |
value="embeddings.npy", | |
label="Numpy Embeddings" | |
) | |
load_btn = gr.Button("Завантажити дані") | |
load_output = gr.Label(label="Результат завантаження") | |
gr.Markdown("## 2) Класифікація") | |
with gr.Row(): | |
filter_in = gr.Textbox(label="Фільтр (опціонально)") | |
batch_threshold = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
value=0.3, | |
step=0.05, | |
label="Поріг впевненості" | |
) | |
classify_btn = gr.Button("Класифікувати") | |
classify_out = gr.Dataframe(label="Результат (Message / Target / Scores)") | |
gr.Markdown("## 3) Зберегти результати") | |
save_btn = gr.Button("Зберегти розмічені дані") | |
save_out = gr.Label() | |
gr.Markdown(""" | |
### Інструкція: | |
1. У вкладці "Налаштування моделі" можна: | |
- Вибрати тип моделі (OpenAI або Local) | |
- Налаштувати параметри вибраної моделі | |
- Завантажити новий JSON файл з класами | |
- Примусово перебудувати signatures | |
2. Після зміни налаштувань натисніть "Оновити signatures" | |
3. Використовуйте повзунок "Поріг впевненості" для фільтрації результатів | |
4. На вкладці "Пакетна обробка" можна аналізувати багато повідомлень | |
5. Результати можна зберегти в CSV файл | |
""") | |
# Підключення обробників подій | |
def update_model_inputs(model_type): | |
"""Оновлення видимості полів в залежності від типу моделі""" | |
return { | |
model_choice: gr.update(visible=model_type == "OpenAI"), | |
local_model_path: gr.update(visible=model_type == "Local"), | |
device_choice: gr.update(visible=model_type == "Local") | |
} | |
def update_classifier_settings(json_file, model_type, openai_model, | |
local_model, device, force_rebuild): | |
"""Оновлення налаштувань класифікатора""" | |
try: | |
# Створюємо новий класифікатор з вибраними параметрами | |
nonlocal classifier | |
classifier = create_classifier( | |
model_type=model_type, | |
openai_model=openai_model if model_type == "OpenAI" else None, | |
local_model=local_model if model_type == "Local" else None, | |
device=device if model_type == "Local" else None | |
) | |
# Завантажуємо класи | |
if json_file is not None: | |
with open(json_file.name, 'r', encoding='utf-8') as f: | |
new_classes = json.load(f) | |
classifier.load_classes(new_classes) | |
else: | |
classifier.restore_base_state() | |
# Ініціалізуємо signatures | |
result = classifier.initialize_signatures( | |
force_rebuild=force_rebuild, | |
signatures_file=DEFAULT_SIGNATURES_FILE if not force_rebuild else None | |
) | |
# Оновлюємо інформацію про систему | |
classifier.save_model_info("model_info.json") | |
with open("model_info.json", "r") as f: | |
model_info = json.load(f) | |
system_info.update(value={ | |
"status": "success", | |
"model_info": model_info, | |
"classes_info": { | |
"total_classes": len(classifier.classes_json), | |
"classes_list": list(classifier.classes_json.keys()), | |
"hints_per_class": {cls: len(hints) | |
for cls, hints in classifier.classes_json.items()} | |
}, | |
"errors": [] | |
}) | |
return result, classifier.get_cache_stats() | |
except Exception as e: | |
return f"Помилка: {str(e)}", classifier.get_cache_stats() | |
def process_single_text(text, threshold): | |
"""Обробка одного тексту""" | |
try: | |
return classifier.process_single_text(text, threshold) | |
except Exception as e: | |
return {"error": str(e)} | |
def load_data(csv_path, emb_path): | |
"""Завантаження даних для пакетної обробки""" | |
try: | |
return classifier.load_data(csv_path, emb_path) | |
except Exception as e: | |
return f"Помилка: {str(e)}" | |
def classify_batch(filter_str, threshold): | |
"""Пакетна класифікація""" | |
try: | |
return classifier.classify_rows(filter_str, threshold) | |
except Exception as e: | |
return None | |
def save_results(): | |
"""Збереження результатів""" | |
try: | |
return classifier.save_results() | |
except Exception as e: | |
return f"Помилка: {str(e)}" | |
# Підключення подій | |
model_type.change( | |
fn=update_model_inputs, | |
inputs=[model_type], | |
outputs=[model_choice, local_model_path, device_choice] | |
) | |
build_btn.click( | |
fn=update_classifier_settings, | |
inputs=[ | |
json_file, | |
model_type, | |
model_choice, | |
local_model_path, | |
device_choice, | |
force_rebuild | |
], | |
outputs=[build_out, cache_stats] | |
) | |
single_process_btn.click( | |
fn=process_single_text, | |
inputs=[text_input, threshold_slider], | |
outputs=result_text | |
) | |
load_btn.click( | |
fn=load_data, | |
inputs=[csv_input, emb_input], | |
outputs=load_output | |
) | |
classify_btn.click( | |
fn=classify_batch, | |
inputs=[filter_in, batch_threshold], | |
outputs=classify_out | |
) | |
save_btn.click( | |
fn=save_results, | |
inputs=[], | |
outputs=save_out | |
) | |
# Запуск веб-інтерфейсу | |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True) | |
if __name__ == "__main__": | |
main() |