|
import logging |
|
import streamlit as st |
|
import pandas as pd |
|
|
|
from src.configuration.config import SessionStateConfig |
|
from src.nlp.playground.textsummarization import SumySummarizer |
|
from src.nlp.playground.pipelines.title_extractor import TitleExtractor |
|
from src.utils.helpers import normalize_data |
|
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer |
|
from src.nlp.playground.llm import QwenLlmHandler |
|
from src.nlp.playground.ner import GlinerHandler |
|
from src.persistence.db import init_db |
|
from src.nlp.playground.textclassification import ZeroShotClassifier, CategoryMode, CustomMode |
|
|
|
# JSON-style schema fragments offered as extraction targets. Each entry is
# rendered as a checkbox label (and used as its widget key) in the LLM
# expander; the selected entries are joined into the entity list passed to
# the LLM prompt.
# Bug fix: every entry previously ended with a stray '"' (e.g.
# '"title": str | None"'), which corrupted the schema text sent to the LLM.
# NOTE(review): "accesibilityInformation" looks misspelled — confirm it must
# match a downstream schema key before renaming it.
entities_schema = [
    '"title": str | None',
    '"organizer": str | None',
    '"startDate": str | None',
    '"endDate": str | None',
    '"startTime": str | None',
    '"endTime": str | None',
    '"admittanceTime": str | None',
    '"locationName": str | None',
    '"street": str | None',
    '"houseNumber": str | None',
    '"postalCode": str | None',
    '"city": str | None',
    '"price": list[float] | None',
    '"currency": str | None',
    '"priceFree": bool | None',
    '"ticketsRequired": bool | None',
    '"categories": list[str] | None',
    '"eventDescription": str | None',
    '"accesibilityInformation": str | None',
    '"keywords": list[str] | None'
]
|
|
|
|
|
@st.cache_resource
def init_connection():
    """Create and cache the database handle for this Streamlit process.

    Wrapped in ``st.cache_resource`` so ``init_db()`` runs only once per
    process instead of on every script rerun.
    """
    return init_db()
|
|
|
@st.cache_resource
def init_data():
    """Return the cached query result of finalized event-detail pages.

    Only the fields needed by the playground (url, base_url_id,
    cleaned_html, data) are projected.

    NOTE(review): ``st.cache_resource`` caches the returned cursor object
    itself — presumably a pymongo cursor, which is exhausted after one full
    iteration, so only the first expander run that consumes it sees data.
    Confirm whether a list (or ``st.cache_data``) was intended.
    """
    return db.event_urls.find(filter={"class":"EventDetail", "final":True}, projection={"url":1,"base_url_id":1,"cleaned_html":1, "data":1})
|
|
|
def render_md(md):
    """Render the original markdown text inside a fixed-height, bordered box."""
    st.subheader("Original Text:")
    box = st.container(border=True, height=400)
    box.markdown(md)
|
|
|
def render_table(table_data):
    """Show the extracted key/value rows as a static table, then a divider."""
    st.subheader("Extrahierte Daten:")
    st.table(pd.DataFrame(table_data))
    st.markdown("---")
|
|
|
def init_session_state(key, value):
    """Store ``value`` under ``key``, wiping all session state on a cache miss.

    The full wipe releases previously stored model handlers before a new
    one is inserted; an existing key is left untouched.
    """
    if key in st.session_state:
        return
    clear_st_cache()
    st.session_state[key] = value
|
|
|
def clear_st_cache():
    """Remove every entry from Streamlit's session state."""
    # Snapshot the keys first — deleting while iterating the live view
    # would raise.
    for key in tuple(st.session_state.keys()):
        del st.session_state[key]
|
|
|
|
|
# Shared resources for all expanders below.
db = init_connection()
# NOTE(review): `data` is a cached query cursor shared by every expander;
# if it is a lazily-evaluated cursor, the first section that iterates it to
# exhaustion leaves nothing for the others — confirm intended.
data = init_data()
|
|
|
# --- LLM-based extraction playground ---
with st.expander("Large Language Models"):
    with st.form("Settings LLM"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        st.write("Welche Informationen sollen extrahiert werden?")
        # One checkbox per schema entry; the entry text doubles as widget key.
        options = [st.checkbox(ent, key=ent) for ent in entities_schema]
        submit_llm = st.form_submit_button("Start")

    if submit_llm:
        selected_entities = [entity for entity, selected in zip(entities_schema, options) if selected]
        init_session_state(SessionStateConfig.QWEN_LLM_HANDLER, QwenLlmHandler())
        qwen_llm_handler = st.session_state[SessionStateConfig.QWEN_LLM_HANDLER]
        try:
            for event in data:
                extracted_data = qwen_llm_handler.extract_data(text=event["data"], entities=", ".join(selected_entities))
                render_md(event["data"])
                render_table([{"Key": k, "Value": v} for k, v in extracted_data.items()])
                count -= 1
                if count == 0:
                    break
        except Exception as e:
            st.write(f"Es ist ein Fehler aufgetreten: {e}")
|
|
|
# --- GLiNER named-entity-recognition playground ---
with st.expander("Named Entity Recognition"):
    with st.form("Settings NER"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        label_input = st.text_input("Gebe die Labels der Entitäten getrennt durch Komma an.")
        submit_ner = st.form_submit_button("Start")

    if submit_ner:
        # Create the handler once per session (wipes older cached handlers).
        init_session_state(SessionStateConfig.GLINER_HANDLER, GlinerHandler())
        gliner_handler = st.session_state[SessionStateConfig.GLINER_HANDLER]
        if label_input:
            labels = label_input.split(",")
            for event in data:
                normalized = normalize_data(event["data"])
                render_md(normalized)
                matches = gliner_handler.extract_entities(normalized, labels)
                render_table([{"Key": m["label"], "Value": m["text"]} for m in matches])
                count -= 1
                if count == 0:
                    break
|
|
|
# --- Zero-shot text-classification playground ---
with st.expander("Textclassification"):
    with st.form("Settings TextClassification"):
        mode = st.selectbox("Classification Mode", ["Categories", "Custom"])
        custom_labels = st.text_input("(Nur bei Custom Mode) Gib die Klassen Labels ein, durch Komma getrennt.", placeholder="Theater,Oper,Film")
        custom_hypothesis_template = st.text_input("(Nur bei Custom Mode) Gib das Template ein. {} ist dabei der Platzhalter für die Labels", placeholder="Die Art der Veranstaltung ist {}")
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_textclass = st.form_submit_button("Start")

    if submit_textclass:
        init_session_state(SessionStateConfig.ZERO_SHOT_CLASSIFIER, ZeroShotClassifier())
        classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]
        # Bug fix: previously, choosing "Custom" while leaving labels or the
        # template empty left classifier_mode unbound and the loop crashed
        # with a NameError. Use a sentinel and report instead.
        classifier_mode = None
        if mode == "Categories":
            classifier_mode = CategoryMode()
        elif custom_labels and custom_hypothesis_template:
            classifier_mode = CustomMode(labels=custom_labels.split(","), hypothesis_template=custom_hypothesis_template)

        if classifier_mode is None:
            st.write("Bitte Labels und Template für den Custom Mode angeben.")
        else:
            for event in data:
                text = normalize_data(event["data"])
                predictions = classifier.classify(text, classifier_mode)
                table_data = [{"Kategorie": p.label, "Score": p.score} for p in predictions]

                render_md(text)
                render_table(table_data)

                count -= 1
                if count == 0:
                    break
|
|
|
# --- Title-extraction playground (zero-shot vs. few-shot) ---
with st.expander("Titel Extraktion"):
    with st.form("Settings TitleExtraction"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_title_extr = st.form_submit_button("Start")

    if submit_title_extr:
        init_session_state("title_extractor", TitleExtractor())
        title_extractor = st.session_state.title_extractor

        for event in data:
            text = normalize_data(event["data"])
            zero_shot_title = title_extractor.extract_title(text)
            try:
                few_shot_title = title_extractor.extract_title_classy_classification(text)
            except FileNotFoundError:
                # The few-shot model must be trained before first use.
                few_shot_title = "ERROR: Train Model before usage"
            render_md(text)
            render_table([
                {"Label": "Titel (ZeroShot)", "Value": zero_shot_title},
                {"Label": "Titel (FewShot)", "Value": few_shot_title},
            ])
            count -= 1
            if count == 0:
                break
|
|
|
# --- Extractive summarization playground (sumy over markdown blocks) ---
with st.expander("Textsummarization"):
    with st.form("Settings Textsummarization"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_textsummarization = st.form_submit_button("Start")

    if submit_textsummarization:
        init_session_state(SessionStateConfig.SUMY_SUMMARIZER, SumySummarizer())
        sumy_summarizer = st.session_state[SessionStateConfig.SUMY_SUMMARIZER]
        for event in data:
            try:
                md = normalize_data(event["data"])
                # Block-level elements of the markdown document, restored to
                # document order via their line numbers.
                md_analyzer = MarkdownAnalyzer(md).identify_all()["block_elements"]
                md_analyzer = sorted(md_analyzer, key=lambda el: el.line)
                text = "\n\n".join([el.text for el in md_analyzer])
                sumy_summary = sumy_summarizer.summarize(text)
                summary = []
                # Map the summary sentences back onto the original markdown:
                # keep every block whose markdown contains at least one of
                # the extracted sentences, preserving document order.
                for element in md_analyzer:
                    if any(sentence in element.markdown for sentence in sumy_summary):
                        summary.append(element.markdown)

                render_md(md)
                st.subheader("Extrahierte Daten:")
                with st.container(border=True, height=400):
                    st.markdown("\n\n".join(summary))

            except Exception as e:
                # Best-effort per event: show the error and continue with
                # the next one.
                st.error(f"Fehler:{e}")
                logging.exception("message")
            count -= 1
            if count == 0:
                break