|
from src.configuration.config import SessionStateConfig |
|
from src.nlp.playground.textclassification import ZeroShotClassifier, TitleMode |
|
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer |
|
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownElements import Header |
|
import streamlit as st |
|
import pickle |
|
import joblib |
|
from huggingface_hub import login, hf_hub_download |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Access token for the Hugging Face Hub; None when the variable is not set.
token = os.getenv("HUGGING_FACE_SPACES_TOKEN")

# Authenticate only when a token is configured: login(token=None) falls back to
# an interactive prompt, which cannot be answered in a non-interactive process
# and would block or fail while this module is being imported.
if token:
    login(token=token)
|
|
|
class TitleExtractor:
    """Select the title of an event description from its Markdown headers.

    Every header candidate is labelled by a classifier. The first candidate
    wins when it carries the title label; otherwise the candidate with the
    lowest ``level`` value is returned.
    """

    # Hugging Face Hub coordinates of the pickled classifier used by
    # extract_title_classy_classification().
    _CLASSY_REPO_ID = "adojode/title_classifier"
    _CLASSY_FILENAME = "title_classifier.pkl"

    # Hub classifier; stays None until a header actually has to be labelled,
    # then the loaded object is stored on the instance and reused.
    _classy_classifier = None

    def __init__(self):
        """Bind the zero-shot classifier kept in the Streamlit session state.

        The classifier is created and stored in the session state the first
        time the key is missing, so later instances reuse the same object.
        """
        if SessionStateConfig.ZERO_SHOT_CLASSIFIER not in st.session_state:
            st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER] = ZeroShotClassifier()
        self.classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]

    def extract_title(self, event_text):
        """Return the title of ``event_text`` using the zero-shot classifier.

        Args:
            event_text: Markdown text handed to ``MarkdownAnalyzer``.

        Returns:
            The text of the chosen header, or ``""`` when no header candidates
            were found.
        """

        def label_of(header_text):
            predictions = self.classifier.classify(text=header_text, mode=TitleMode())
            return predictions[0].label if predictions else "Unknown"

        return self.__find_title(self._label_headers(event_text, label_of))

    def __find_title(self, header_labels, event_title_label="Titel"):
        """Pick the title text from a list of labelled header candidates.

        Args:
            header_labels: List of dicts with the keys ``"text"``, ``"label"``
                and ``"level"``.
            event_title_label: Label value that marks a candidate as the title.

        Returns:
            ``""`` for an empty list. Otherwise the text of the first candidate
            when it is the only one or carries ``event_title_label``; else the
            text of the candidate with the smallest ``"level"``.
        """
        if not header_labels:
            return ""

        first = header_labels[0]
        if len(header_labels) == 1 or first["label"] == event_title_label:
            return first["text"]

        # min() returns the first minimal element, so ties on "level" resolve
        # to the earliest candidate.
        return min(header_labels, key=lambda candidate: candidate["level"])["text"]

    def extract_title_classy_classification(self, event_text):
        """Return the title of ``event_text`` using the classifier from the Hub.

        The classifier is only fetched when at least one header has to be
        labelled, so texts without header candidates cause no download.

        Args:
            event_text: Markdown text handed to ``MarkdownAnalyzer``.

        Returns:
            The text of the chosen header, or ``""`` when no header candidates
            were found.
        """

        def label_of(header_text):
            # The loaded classifier is called directly; its result is used as a
            # mapping whose key with the highest value becomes the label.
            predictions = self._load_classy_classifier()(header_text)
            return max(predictions, key=predictions.get) if predictions else "Unknown"

        return self.__find_title(
            self._label_headers(event_text, label_of), "Veranstaltungstitel"
        )

    def _load_classy_classifier(self):
        """Return the Hub classifier, downloading and unpickling it on first use."""
        if self._classy_classifier is None:
            # SECURITY NOTE(review): joblib.load unpickles the downloaded file
            # and can therefore execute arbitrary code; only keep this if the
            # Hub repository is trusted.
            self._classy_classifier = joblib.load(
                hf_hub_download(
                    repo_id=self._CLASSY_REPO_ID, filename=self._CLASSY_FILENAME
                )
            )
        return self._classy_classifier

    def _collect_headers(self, event_text):
        """Return the header candidates of ``event_text``.

        Uses the ``"Header"`` entry of ``identify_headers()`` when that result
        is truthy, otherwise falls back to ``identify_emphasis()``.
        """
        analyzer = MarkdownAnalyzer(event_text)
        identified_headers = analyzer.identify_headers()
        if identified_headers:
            return identified_headers["Header"]
        return analyzer.identify_emphasis()

    def _label_headers(self, event_text, label_of):
        """Build the labelled candidate list consumed by ``__find_title``.

        Args:
            event_text: Markdown text to analyse.
            label_of: Callable mapping a header text to its label.

        Returns:
            A list of ``{"text", "label", "level"}`` dicts; empty when there
            are no header candidates.
        """
        headers = self._collect_headers(event_text)
        if not headers:
            return []

        labelled = []
        for header in headers:
            # The analysed text contains "\." sequences; turn them into plain dots.
            header_text = header.text.replace("\\.", ".")
            labelled.append({
                "text": header_text,
                "label": label_of(header_text),
                # Candidates that are not Header elements are treated as level 1.
                "level": header.level if isinstance(header, Header) else 1,
            })
        return labelled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|