# Source: manaviel85370 — commit 56abaa1 ("refactor infos")
import logging
import streamlit as st
import pandas as pd
from src.configuration.config import SessionStateConfig
from src.nlp.playground.textsummarization import SumySummarizer
from src.nlp.playground.pipelines.title_extractor import TitleExtractor
from src.utils.helpers import normalize_data
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer
from src.nlp.playground.llm import QwenLlmHandler
from src.nlp.playground.ner import GlinerHandler
from src.persistence.db import init_db
from src.nlp.playground.textclassification import ZeroShotClassifier, CategoryMode, CustomMode
# Schema fragments describing the event fields the LLM should extract.
# Each entry is rendered as a checkbox label and, when selected, joined
# into the extraction prompt (see the "Large Language Models" expander).
# BUG FIX: every entry previously ended with a stray, unmatched '"'
# (e.g. '"title": str | None"'), which corrupted the prompt fragments.
# NOTE(review): "accesibilityInformation" is misspelled but kept as-is —
# it presumably matches an external schema key; verify before renaming.
entities_schema = [
    '"title": str | None',
    '"organizer": str | None',
    '"startDate": str | None',
    '"endDate": str | None',
    '"startTime": str | None',
    '"endTime": str | None',
    '"admittanceTime": str | None',
    '"locationName": str | None',
    '"street": str | None',
    '"houseNumber": str | None',
    '"postalCode": str | None',
    '"city": str | None',
    '"price": list[float] | None',
    '"currency": str | None',
    '"priceFree": bool | None',
    '"ticketsRequired": bool | None',
    '"categories": list[str] | None',
    '"eventDescription": str | None',
    '"accesibilityInformation": str | None',
    '"keywords": list[str] | None'
]
@st.cache_resource
def init_connection():
    """Open the database connection once and reuse it across Streamlit reruns."""
    connection = init_db()
    return connection
@st.cache_resource
def init_data():
    """Load all finalized event-detail documents, cached for the session.

    BUG FIX: the previous version cached the raw cursor returned by
    ``find``.  A cursor is exhausted after a single iteration, so every
    later form submission (each expander iterates ``data``) silently saw
    no documents.  Materializing into a list makes the cached value
    safely re-iterable.  NOTE(review): this loads all matching documents
    (including ``cleaned_html``) into memory — confirm the result set is
    small enough for that trade-off.
    """
    return list(db.event_urls.find(
        filter={"class": "EventDetail", "final": True},
        projection={"url": 1, "base_url_id": 1, "cleaned_html": 1, "data": 1},
    ))
def render_md(md):
    """Display the original markdown text inside a fixed-height, bordered box."""
    st.subheader("Original Text:")
    box = st.container(border=True, height=400)
    with box:
        st.markdown(md)
def render_table(table_data):
    """Render extracted key/value rows as a static table, followed by a divider."""
    st.subheader("Extrahierte Daten:")
    frame = pd.DataFrame(table_data)
    st.table(frame)
    st.markdown("---")
def init_session_state(key, value):
    """Store *value* under *key* in session state.

    The first time a key is seen, the entire session state is wiped first,
    so switching between experiments discards the previous handlers.
    """
    if key in st.session_state:
        return
    clear_st_cache()
    st.session_state[key] = value
def clear_st_cache():
    """Remove every entry currently held in Streamlit's session state."""
    for key in list(st.session_state.keys()):
        st.session_state.pop(key)
# Shared handles for every experiment below: the cached DB connection and
# the event-detail documents each expander iterates over.
db = init_connection()
data = init_data()
# Experiment 1: structured extraction with an LLM (Qwen).
with st.expander("Large Language Models"):
    with st.form("Settings LLM"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        st.write("Welche Informationen sollen extrahiert werden?")
        # One checkbox per schema entry, keyed by the entry text itself.
        options = [st.checkbox(ent, key=ent) for ent in entities_schema]
        submit_llm = st.form_submit_button("Start")
    if submit_llm:
        selected_entities = [entity for entity, selected in zip(entities_schema, options) if selected]
        init_session_state(SessionStateConfig.QWEN_LLM_HANDLER, QwenLlmHandler())
        handler = st.session_state[SessionStateConfig.QWEN_LLM_HANDLER]
        try:
            for event in data:
                result = handler.extract_data(text=event["data"], entities=", ".join(selected_entities))
                rows = [{"Key": key, "Value": value} for key, value in result.items()]
                render_md(event["data"])
                render_table(rows)
                count -= 1
                if count == 0:
                    break
        except Exception as e:
            st.write(f"Es ist ein Fehler aufgetreten: {e}")
# Experiment 2: entity extraction with GLiNER using user-supplied labels.
with st.expander("Named Entity Recognition"):
    with st.form("Settings NER"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        label_input = st.text_input("Gebe die Labels der Entitäten getrennt durch Komma an.")
        submit_ner = st.form_submit_button("Start")
    if submit_ner:
        init_session_state(SessionStateConfig.GLINER_HANDLER, GlinerHandler())
        ner = st.session_state[SessionStateConfig.GLINER_HANDLER]
        if label_input:
            labels = label_input.split(",")
            for event in data:
                text = normalize_data(event["data"])
                render_md(text)
                entities = ner.extract_entities(text, labels)
                rows = [{"Key": element["label"], "Value": element["text"]} for element in entities]
                render_table(rows)
                count -= 1
                if count == 0:
                    break
# Experiment 3: zero-shot text classification (predefined categories or
# user-supplied custom labels + hypothesis template).
with st.expander("Textclassification"):
    with st.form("Settings TextClassification"):
        mode = st.selectbox("Classification Mode", ["Categories", "Custom"])
        custom_labels = st.text_input("(Nur bei Custom Mode) Gib die Klassen Labels ein, durch Komma getrennt.", placeholder="Theater,Oper,Film")
        custom_hypothesis_template = st.text_input("(Nur bei Custom Mode) Gib das Template ein. {} ist dabei der Platzhalter für die Labels", placeholder="Die Art der Veranstaltung ist {}")
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_textclass = st.form_submit_button("Start")
    if submit_textclass:
        init_session_state(SessionStateConfig.ZERO_SHOT_CLASSIFIER, ZeroShotClassifier())
        classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]
        # BUG FIX: classifier_mode was left unbound (NameError at classify())
        # when "Custom" was selected without labels or template; initialise
        # to None and show an error instead of crashing.
        classifier_mode = None
        if mode == "Categories":
            classifier_mode = CategoryMode()
        elif custom_labels and custom_hypothesis_template:
            classifier_mode = CustomMode(labels=custom_labels.split(","), hypothesis_template=custom_hypothesis_template)
        if classifier_mode is None:
            st.error("Custom Mode benötigt Labels und ein Template.")
        else:
            for event in data:
                text = normalize_data(event["data"])
                predictions = classifier.classify(text, classifier_mode)
                table_data = [{"Kategorie": p.label, "Score": p.score} for p in predictions]
                render_md(text)
                render_table(table_data)
                count -= 1
                if count == 0:
                    break
# Experiment 4: title extraction — zero-shot vs. few-shot (classy-classification).
with st.expander("Titel Extraktion"):
    with st.form("Settings TitleExtraction"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_title_extr = st.form_submit_button("Start")
    if submit_title_extr:
        init_session_state("title_extractor", TitleExtractor())
        extractor = st.session_state.title_extractor
        for event in data:
            text = normalize_data(event["data"])
            zero_shot = extractor.extract_title(text)
            try:
                few_shot = extractor.extract_title_classy_classification(text)
            except FileNotFoundError:
                # Few-shot model not trained/persisted yet — report instead of crashing.
                few_shot = "ERROR: Train Model before usage"
            render_md(text)
            render_table([
                {"Label": "Titel (ZeroShot)", "Value": zero_shot},
                {"Label": "Titel (FewShot)", "Value": few_shot},
            ])
            count -= 1
            if count == 0:
                break
# Experiment 5: extractive summarization — summarize the plain text with sumy,
# then keep only those markdown block elements containing a summary sentence,
# so the output preserves the original markdown formatting.
with st.expander("Textsummarization"):
    with st.form("Settings Textsummarization"):
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", step=1)
        submit_textsummarization = st.form_submit_button("Start")
    if submit_textsummarization:
        init_session_state(SessionStateConfig.SUMY_SUMMARIZER, SumySummarizer())
        summarizer = st.session_state[SessionStateConfig.SUMY_SUMMARIZER]
        for event in data:
            try:
                md = normalize_data(event["data"])
                elements = sorted(
                    MarkdownAnalyzer(md).identify_all()["block_elements"],
                    key=lambda el: el.line,
                )
                plain_text = "\n\n".join(el.text for el in elements)
                sentences = summarizer.summarize(plain_text)
                summary = [
                    element.markdown
                    for element in elements
                    if any(sentence in element.markdown for sentence in sentences)
                ]
                render_md(md)
                st.subheader("Extrahierte Daten:")
                with st.container(border=True, height=400):
                    st.markdown("\n\n".join(summary))
            except Exception as e:
                st.error(f"Fehler:{e}")
                logging.exception("message")
            count -= 1
            if count == 0:
                break