File size: 8,244 Bytes
da88570 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import logging
import streamlit as st
import pandas as pd
from src.configuration.config import SessionStateConfig
from src.nlp.playground.textsummarization import SumySummarizer
from src.nlp.playground.pipelines.title_extractor import TitleExtractor
from src.utils.helpers import normalize_data
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer
from src.nlp.playground.llm import QwenLlmHandler
from src.nlp.playground.ner import GlinerHandler
from src.persistence.db import init_db
from src.nlp.playground.textclassification import ZeroShotClassifier, CategoryMode, CustomMode
# Schema of the fields the LLM extractor can be asked for, one entry per
# field rendered as `"<field>": <type>`. The spurious trailing double quote
# that ended every original entry (e.g. '"title": str | None"') has been
# removed so each schema line is well-formed in the prompt.
# NOTE(review): "accesibilityInformation" looks like a typo for
# "accessibilityInformation", but it is kept unchanged because downstream
# consumers may match on the exact key — confirm before renaming.
entities_schema = [
    '"title": str | None',
    '"organizer": str | None',
    '"startDate": str | None',
    '"endDate": str | None',
    '"startTime": str | None',
    '"endTime": str | None',
    '"admittanceTime": str | None',
    '"locationName": str | None',
    '"street": str | None',
    '"houseNumber": str | None',
    '"postalCode": str | None',
    '"city": str | None',
    '"price": list[float] | None',
    '"currency": str | None',
    '"priceFree": bool | None',
    '"ticketsRequired": bool | None',
    '"categories": list[str] | None',
    '"eventDescription": str | None',
    '"accesibilityInformation": str | None',
    '"keywords": list[str] | None'
]
@st.cache_resource
def init_connection():
    """Open the database connection once and cache it for the app's lifetime."""
    connection = init_db()
    return connection
@st.cache_resource
def init_data():
    """Fetch and cache the event documents used by every section below.

    The cursor is materialized into a list: ``@st.cache_resource`` would
    otherwise cache the raw MongoDB cursor, which is exhausted after the
    first iteration and yields nothing on subsequent Streamlit reruns.
    Relies on the module-level ``db`` connection being bound before the
    first call (it is, at the top of the script body).
    """
    cursor = db.event_urls.find(
        filter={"class": "EventDetail", "final": True},
        projection={"url": 1, "base_url_id": 1, "cleaned_html": 1, "data": 1},
    )
    return list(cursor)
def render_md(md):
    """Display the original markdown text inside a fixed-height, bordered box."""
    st.subheader("Original Text:")
    box = st.container(border=True, height=400)
    with box:
        st.markdown(md)
def render_table(table_data):
    """Render the extracted key/value rows as a table, followed by a divider."""
    st.subheader("Extrahierte Daten:")
    st.table(pd.DataFrame(table_data))
    st.markdown("---")
def init_session_state(key, value):
    """Store *value* under *key*, wiping all session state first on a cache miss."""
    if key in st.session_state:
        return
    # A missing key means a fresh model/handler is being loaded; drop every
    # previously cached entry before storing the new one.
    clear_st_cache()
    st.session_state[key] = value
def clear_st_cache():
    """Remove every entry currently held in Streamlit's session state."""
    # Snapshot the keys first so the mapping is not mutated while iterating.
    for entry in list(st.session_state.keys()):
        del st.session_state[entry]
# Module-level setup: open the (cached) DB connection and load the event
# documents that every playground section below iterates over.
db = init_connection()
data = init_data()
with st.expander("Large Language Models"):
    # Form: how many events to test and which schema fields to extract.
    with st.form("Settings LLM"):
        # min_value=1 keeps the counter positive; the original default of 0
        # combined with the post-decrement `== 0` check meant the loop below
        # never broke and processed every event in the collection.
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", min_value=1, step=1)
        st.write("Welche Informationen sollen extrahiert werden?")
        options = []
        for ent in entities_schema:
            option = st.checkbox(ent, key=ent)
            options.append(option)
        submit_llm = st.form_submit_button("Start")
    if submit_llm:
        selected_entities = [entity for entity, selected in zip(entities_schema, options) if selected]
        init_session_state(SessionStateConfig.QWEN_LLM_HANDLER, QwenLlmHandler())
        qwen_llm_handler = st.session_state[SessionStateConfig.QWEN_LLM_HANDLER]
        try:
            for event in data:
                extracted_data = qwen_llm_handler.extract_data(text=event["data"], entities=", ".join(selected_entities))
                table_data = [{"Key": key, "Value": value} for key, value in extracted_data.items()]
                render_md(event["data"])
                render_table(table_data)
                count -= 1
                # `<= 0` (not `== 0`) so a non-positive start value still terminates.
                if count <= 0:
                    break
        except Exception as e:
            st.write(f"Es ist ein Fehler aufgetreten: {e}")
with st.expander("Named Entity Recognition"):
    # Form: how many events to test plus a comma-separated list of entity labels.
    with st.form("Settings NER"):
        # min_value=1 fixes the runaway loop: with the original default of 0
        # the post-decrement `== 0` check below never fired.
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", min_value=1, step=1)
        label_input = st.text_input("Gebe die Labels der Entitäten getrennt durch Komma an.")
        submit_ner = st.form_submit_button("Start")
    if submit_ner:
        init_session_state(SessionStateConfig.GLINER_HANDLER, GlinerHandler())
        gliner_handler = st.session_state[SessionStateConfig.GLINER_HANDLER]
        if label_input:
            # Strip whitespace so "a, b" yields the labels "a" and "b", and
            # drop empty fragments produced by trailing commas.
            labels = [label.strip() for label in label_input.split(",") if label.strip()]
            for event in data:
                text = normalize_data(event["data"])
                render_md(text)
                extracted_data = gliner_handler.extract_entities(text, labels)
                table_data = [{"Key": element["label"], "Value": element["text"]} for element in extracted_data]
                render_table(table_data)
                count -= 1
                # `<= 0` so any non-positive count still terminates the loop.
                if count <= 0:
                    break
with st.expander("Textclassification"):
    # Form: classification mode (predefined categories vs. custom zero-shot labels).
    with st.form("Settings TextClassification"):
        mode = st.selectbox("Classification Mode", ["Categories", "Custom"])
        custom_labels = st.text_input("(Nur bei Custom Mode) Gib die Klassen Labels ein, durch Komma getrennt.", placeholder="Theater,Oper,Film")
        custom_hypothesis_template = st.text_input("(Nur bei Custom Mode) Gib das Template ein. {} ist dabei der Platzhalter für die Labels", placeholder="Die Art der Veranstaltung ist {}")
        # min_value=1 fixes the runaway loop caused by the original default 0
        # together with the post-decrement `== 0` check.
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", min_value=1, step=1)
        submit_textclass = st.form_submit_button("Start")
    if submit_textclass:
        init_session_state(SessionStateConfig.ZERO_SHOT_CLASSIFIER, ZeroShotClassifier())
        classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]
        # Resolve the mode up front. The original code left classifier_mode
        # unbound (NameError at classify time) when "Custom" was selected but
        # labels or template were missing; report that case explicitly instead.
        classifier_mode = None
        if mode == "Categories":
            classifier_mode = CategoryMode()
        elif custom_labels and custom_hypothesis_template:
            classifier_mode = CustomMode(labels=custom_labels.split(","), hypothesis_template=custom_hypothesis_template)
        if classifier_mode is None:
            st.error("Bitte Labels und Template für den Custom Mode angeben.")
        else:
            for event in data:
                text = normalize_data(event["data"])
                predictions = classifier.classify(text, classifier_mode)
                table_data = [{"Kategorie": p.label, "Score": p.score} for p in predictions]
                render_md(text)
                render_table(table_data)
                count -= 1
                if count <= 0:
                    break
with st.expander("Titel Extraktion"):
    with st.form("Settings TitleExtraction"):
        # min_value=1 fixes the runaway loop caused by the original default 0
        # together with the post-decrement `== 0` check.
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", min_value=1, step=1)
        submit_title_extr = st.form_submit_button("Start")
    if submit_title_extr:
        init_session_state("title_extractor", TitleExtractor())
        title_extractor = st.session_state.title_extractor
        for event in data:
            text = normalize_data(event["data"])
            prediction = title_extractor.extract_title(text)
            try:
                pred2 = title_extractor.extract_title_classy_classification(text)
            except FileNotFoundError:
                # Few-shot model artifacts are missing until the model is trained.
                pred2 = "ERROR: Train Model before usage"
            table_data = [{"Label": "Titel (ZeroShot)", "Value": prediction}, {"Label": "Titel (FewShot)", "Value": pred2}]
            render_md(text)
            render_table(table_data)
            count -= 1
            # `<= 0` so any non-positive count still terminates the loop.
            if count <= 0:
                break
with st.expander("Textsummarization"):
    with st.form("Settings Textsummarization"):
        # min_value=1 fixes the runaway loop caused by the original default 0
        # together with the post-decrement `== 0` check.
        count = st.number_input("Wie viele Veranstaltungen sollen gestest werden?", min_value=1, step=1)
        submit_textsummarization = st.form_submit_button("Start")
    if submit_textsummarization:
        init_session_state(SessionStateConfig.SUMY_SUMMARIZER, SumySummarizer())
        sumy_summarizer = st.session_state[SessionStateConfig.SUMY_SUMMARIZER]
        for event in data:
            try:
                md = normalize_data(event["data"])
                # Break the markdown into its block elements, in document order.
                block_elements = MarkdownAnalyzer(md).identify_all()["block_elements"]
                block_elements = sorted(block_elements, key=lambda el: el.line)
                text = "\n\n".join(el.text for el in block_elements)
                sumy_summary = sumy_summarizer.summarize(text)
                # Keep only the markdown blocks containing a summary sentence,
                # preserving the original formatting of the selected blocks.
                summary = [el.markdown for el in block_elements
                           if any(sentence in el.markdown for sentence in sumy_summary)]
                render_md(md)
                st.subheader("Extrahierte Daten:")
                with st.container(border=True, height=400):
                    st.markdown("\n\n".join(summary))
            except Exception as e:
                st.error(f"Fehler:{e}")
                # Descriptive message instead of the original placeholder "message".
                logging.exception("Summarization of event failed")
            count -= 1
            # `<= 0` so any non-positive count still terminates the loop.
            if count <= 0:
                break