# Page residue captured with this file (not part of the module):
#   elmanavi's picture
#   refactor testing
#   14a5766
import logging
import os
import pickle

import joblib
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import login, hf_hub_download

from src.configuration.config import SessionStateConfig
from src.nlp.playground.textclassification import ZeroShotClassifier, TitleMode
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownElements import Header
# Pull variables from a local .env file (if one exists) into the environment
# before the token is read.
load_dotenv()
token = os.getenv("HUGGING_FACE_SPACES_TOKEN")
# login() without a token falls back to an interactive prompt, which cannot be
# answered in a headless app. Only authenticate when a token is configured;
# otherwise no login is performed here.
if token:
    login(token=token)
logger = logging.getLogger(__name__)


class TitleExtractor:
    """Pick the most likely event title from the headers of a markdown text.

    Headers (or, when the analyzer reports none, emphasised spans) are each
    given a label by a classifier. The title is then chosen from the labelled
    candidates: the first candidate if it carries the title label, otherwise
    the candidate with the lowest header level.
    """

    def __init__(self):
        # Keep one ZeroShotClassifier in Streamlit's session state and reuse
        # it instead of building a new one for every TitleExtractor.
        if SessionStateConfig.ZERO_SHOT_CLASSIFIER not in st.session_state:
            st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER] = ZeroShotClassifier()
        self.classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]

    @staticmethod
    def _collect_headers(event_text):
        """Return the title candidates found in *event_text*.

        Uses the analyzer's headers when it reports any, and falls back to
        its emphasised spans otherwise.
        """
        analyzer = MarkdownAnalyzer(event_text)
        identified_headers = analyzer.identify_headers()
        if identified_headers:
            return identified_headers["Header"]
        return analyzer.identify_emphasis()

    @staticmethod
    def _label_headers(headers, predict_label):
        """Build one ``{"text", "label", "level"}`` dict per candidate.

        Args:
            headers: Candidates exposing a ``text`` attribute.
            predict_label: Callable mapping the cleaned header text to a label.
        """
        header_labels = []
        for header in headers:
            # Undo the markdown escaping of full stops ("\." -> ".").
            header_text = header.text.replace("\\.", ".")
            header_labels.append({
                "text": header_text,
                "label": predict_label(header_text),
                # Candidates that are not Header instances are given level 1.
                "level": header.level if isinstance(header, Header) else 1,
            })
        return header_labels

    def extract_title(self, event_text):
        """Extract the title using the session's zero-shot classifier.

        Args:
            event_text: Markdown text of the event.

        Returns:
            The chosen title text, or ``""`` when no candidates were found.
        """
        headers = self._collect_headers(event_text)
        if not headers:
            return ""

        def predict_label(header_text):
            predictions = self.classifier.classify(text=header_text, mode=TitleMode())
            return predictions[0].label if predictions else "Unknown"

        return self.__find_title(self._label_headers(headers, predict_label))

    def __find_title(self, header_labels, event_title_label="Titel"):
        """Choose the title text from labelled candidates.

        Args:
            header_labels: List of dicts with ``"text"``, ``"label"`` and
                ``"level"`` keys, in document order.
            event_title_label: Label that marks a candidate as the title.

        Returns:
            The first candidate's text if it is the only one or carries
            *event_title_label*; otherwise the text of the first candidate
            with the smallest level. ``""`` for an empty list.
        """
        if not header_labels:
            return ""
        logger.debug("Title candidates: %s", header_labels)
        first = header_labels[0]
        if len(header_labels) == 1 or first["label"] == event_title_label:
            return first["text"]
        # min() keeps the first of several equally low levels, i.e. the
        # earliest such candidate in the document.
        lowest_level = min(header_labels, key=lambda entry: entry["level"])
        logger.debug("Lowest-level candidate: %s", lowest_level)
        return lowest_level["text"]

    def extract_title_classy_classification(self, event_text):
        """Extract the title using the classifier downloaded from the Hub.

        Args:
            event_text: Markdown text of the event.

        Returns:
            The chosen title text, or ``""`` when no candidates were found.
        """
        headers = self._collect_headers(event_text)
        if not headers:
            # Nothing to classify, so skip the model download and load.
            return ""

        # NOTE(review): joblib.load unpickles the downloaded file, which can
        # execute arbitrary code - only acceptable while this repo is trusted.
        classifier = joblib.load(
            hf_hub_download(repo_id="adojode/title_classifier", filename="title_classifier" + ".pkl")
        )

        def predict_label(header_text):
            predictions = classifier(header_text)
            # The predictions are keyed by label; take the highest-valued one.
            return max(predictions, key=predictions.get) if predictions else "Unknown"

        return self.__find_title(self._label_headers(headers, predict_label), "Veranstaltungstitel")
# text = """### Veranstaltungen
#
# Finissage der Ausstellung Bartmann, Bier und Tafelzier.
# =======================================================
#
# Veranstaltungsort
#
# Finissage der Sonderausstellung Bartmann, Bier und Tafelzier. Steinzeug in der niederländischen Malerei
#
# © Museum August Kestner
#
# Diverse Krüge aus der Ausstellung
#
# Mit Kuratorinnenführung
#
# Termine
#
# 18.01.2025 ab 11:00 bis 18:00 Uhr
#
# Ort
#
# Museum August Kestner
# Platz der Menschenrechte 3
# 30159 Hannover"""
#
# print(TitleExtractor().extract_title(text))