File size: 2,685 Bytes
6371026 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import functools
import os
import pickle

import gradio as gr
import pandas as pd
from dotenv import dotenv_values
from transformers import pipeline, BertTokenizer, BertForSequenceClassification

from service_dops_api.dops_config import ServiceDopsConfig
from service_dops_api.dops_classifier import DopsClassifier
# Hugging Face access token passed as ``token=`` to the ``from_pretrained``
# calls below. Read from ``.env`` first; if the file is missing or has no
# (non-empty) HF_TOKEN entry, fall back to the process environment instead of
# crashing with KeyError at import time. May be None if neither is set.
hf_token = dotenv_values('.env').get('HF_TOKEN') or os.environ.get('HF_TOKEN')
@functools.lru_cache(maxsize=1)
def _categoriser_resources():
    """Load the service-name classifier pipeline and its label map once.

    ``from_pretrained`` (possibly a Hub download) and unpickling the label map
    are expensive; caching them means each request only runs inference.
    A failed load raises and is not cached, so the next call retries.

    Returns:
        tuple: ``(clf, id2label)`` -- a ``text-classification`` pipeline and the
        object stored in ``id2label_service_categoriser.pickle``, which is
        indexed by the integer class id.
    """
    repo_id = "warleagle/service_name_categorizer"
    tokenizer = BertTokenizer.from_pretrained(repo_id, token=hf_token)
    model = BertForSequenceClassification.from_pretrained(repo_id, token=hf_token)
    clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
    # NOTE: read_pickle unpickles arbitrary objects (it can execute code), so
    # this path must always point to a trusted, repository-local file.
    id2label = pd.read_pickle('id2label_service_categoriser.pickle')
    return clf, id2label


def categoriser_predict(input_text):
    """Return the service category predicted for *input_text*.

    Runs the ``warleagle/service_name_categorizer`` BERT classifier and maps
    the predicted class id back to a readable category via the label map in
    ``id2label_service_categoriser.pickle``. The model and label map are
    loaded on the first call and reused afterwards.

    Args:
        input_text: Service name to classify (the Gradio textbox value).

    Returns:
        The label-map entry for the top predicted class.

    Raises:
        ValueError: if the predicted label does not end in ``_<integer>``.
        Whatever ``id2label[...]`` raises (e.g. KeyError) if the id is absent.
    """
    clf, id2label = _categoriser_resources()
    predictions = clf(input_text)
    # Labels are expected as "<prefix>_<id>" (the transformers default is
    # "LABEL_<id>"); take the trailing integer as the class id.
    numeric_label = int(predictions[0]['label'].rsplit("_", 1)[-1])
    return id2label[numeric_label]
@functools.lru_cache(maxsize=1)
def _doctor_spec_resources():
    """Load the doctor-specialty classifier pipeline and its label map once.

    ``from_pretrained`` (possibly a Hub download) and unpickling the label map
    are expensive; caching them means each request only runs inference.
    A failed load raises and is not cached, so the next call retries.

    Returns:
        tuple: ``(clf, id2label)`` -- a ``text-classification`` pipeline and the
        object stored in ``id2label_spec_categoriser.pickle``, which is indexed
        by the integer class id.
    """
    repo_id = "warleagle/specialists_categorizer_model"
    tokenizer = BertTokenizer.from_pretrained(repo_id, token=hf_token)
    model = BertForSequenceClassification.from_pretrained(repo_id, token=hf_token)
    clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
    # NOTE: read_pickle unpickles arbitrary objects (it can execute code), so
    # this path must always point to a trusted, repository-local file.
    id2label = pd.read_pickle('id2label_spec_categoriser.pickle')
    return clf, id2label


def doctor_spec_predict(input_text):
    """Return the doctor specialty predicted for *input_text*.

    Runs the ``warleagle/specialists_categorizer_model`` BERT classifier and
    maps the predicted class id back to a readable specialty via the label
    map in ``id2label_spec_categoriser.pickle``. The model and label map are
    loaded on the first call and reused afterwards.

    Args:
        input_text: Service name to classify (the Gradio textbox value).

    Returns:
        The label-map entry for the top predicted class.

    Raises:
        ValueError: if the predicted label does not end in ``_<integer>``.
        Whatever ``id2label[...]`` raises (e.g. KeyError) if the id is absent.
    """
    clf, id2label = _doctor_spec_resources()
    predictions = clf(input_text)
    # Labels are expected as "<prefix>_<id>" (the transformers default is
    # "LABEL_<id>"); take the trailing integer as the class id.
    numeric_label = int(predictions[0]['label'].rsplit("_", 1)[-1])
    return id2label[numeric_label]
def dops_predict(input_text):
    """Run every "dops" extractor on *input_text* and return its result.

    Judging by the UI output label, "dops" are additional service parameters.
    A fresh ``ServiceDopsConfig`` and ``DopsClassifier`` are built on each
    call, and the value of ``run_all_dops`` is returned unchanged.
    """
    return DopsClassifier(config=ServiceDopsConfig()).run_all_dops(input_text)
def service_pipeline(input_text):
    """Classify a service name and, for consultations, enrich the result.

    Returns a 3-tuple matching the three Gradio output textboxes:

    * If the categoriser says the service is a specialist consultation
      ('Консультация специалиста'), return that category, the predicted
      doctor specialty, and the extracted additional parameters.
    * Otherwise return a "not a specialist appointment" message and two
      '-' placeholders; the specialty and dops models are not run.
    """
    category = categoriser_predict(input_text)
    if category == 'Консультация специалиста':
        specialty = doctor_spec_predict(input_text)
        extra_params = dops_predict(input_text)
        return category, specialty, extra_params
    return 'Эта услуга не относится к приему специалиста', '-', '-'
# Gradio UI: one text input (the service name) and three text outputs, one per
# element of the tuple returned by service_pipeline, in the same order.
_service_name_input = gr.components.Textbox(label='Название услуги')
_pipeline_outputs = [
    gr.components.Textbox(label='Относится ли данная услуга к приёму специалиста'),
    gr.components.Textbox(label='Специальность врача'),
    gr.components.Textbox(label='Дополнительные параметры услуги'),
]
demo = gr.Interface(
    fn=service_pipeline,
    inputs=_service_name_input,
    outputs=_pipeline_outputs,
)

if __name__ == "__main__":
    demo.launch()