import logging
import streamlit as st
import pandas as pd

from src.configuration.config import SessionStateConfig
from src.nlp.playground.textsummarization import SumySummarizer
from src.nlp.playground.pipelines.title_extractor import TitleExtractor
from src.utils.helpers import normalize_data
from src.utils.markdown_processing.CustomMarkdownAnalyzer.MarkdownAnalyzer import MarkdownAnalyzer
from src.nlp.playground.llm import QwenLlmHandler
from src.nlp.playground.ner import GlinerHandler
from src.persistence.db import init_db
from src.nlp.playground.textclassification import ZeroShotClassifier, CategoryMode, CustomMode

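# Schema fields offered as checkboxes for the LLM extraction; the selected
# entries are joined into the entity list passed to the Qwen handler.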
entities_schema = [
    '"title": str | None',
    '"organizer": str | None',
    '"startDate": str | None',
    '"endDate": str | None',
    '"startTime": str | None',
    '"endTime": str | None',
    '"admittanceTime": str | None',
    '"locationName": str | None',
    '"street": str | None',
    '"houseNumber": str | None',
    '"postalCode": str | None',
    '"city": str | None',
    '"price": list[float] | None',
    '"currency": str | None',
    '"priceFree": bool | None',
    '"ticketsRequired": bool | None',
    '"categories": list[str] | None',
    '"eventDescription": str | None',
    '"accesibilityInformation": str | None',
    '"keywords": list[str] | None'
]


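# Cache the database connection and the cursor over finalized event detail pages
# across Streamlit reruns.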
@st.cache_resource
def init_connection():
    return init_db()

@st.cache_resource
def init_data():
    return db.event_urls.find(filter={"class":"EventDetail", "final":True}, projection={"url":1,"base_url_id":1,"cleaned_html":1, "data":1})

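# Rendering helpers: show the original markdown of an event and a table with the
# extracted results below it.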
def render_md(md):
    st.subheader("Original Text:")
    with st.container(border=True, height=400):
        st.markdown(md)

def render_table(table_data):
    st.subheader("Extrahierte Daten:")
    df = pd.DataFrame(table_data)
    st.table(df)
    st.markdown("---")

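# Only one model handler is kept in the session state at a time: before a new
# handler is stored, all existing session state entries are cleared.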
def init_session_state(key, value):
    if key not in st.session_state:
        clear_st_cache()
        st.session_state[key] = value

def clear_st_cache():
    keys = list(st.session_state.keys())
    for key in keys:
        st.session_state.pop(key)


db = init_connection()
data = init_data()

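# LLM extraction: the Qwen handler extracts the selected schema fields from each event text.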
with st.expander("Large Language Models"):
    with st.form("Settings LLM"):
        count = st.number_input("Wie viele Veranstaltungen sollen getestet werden?", step=1)
        st.write("Welche Informationen sollen extrahiert werden?")
        options = []
        for ent in entities_schema:
            option = st.checkbox(ent, key=ent)
            options.append(option)
        submit_llm = st.form_submit_button("Start")

    if submit_llm:
        selected_entities = [entity for entity, selected in zip(entities_schema, options) if selected]
        init_session_state(SessionStateConfig.QWEN_LLM_HANDLER, QwenLlmHandler())
        qwen_llm_handler = st.session_state[SessionStateConfig.QWEN_LLM_HANDLER]
        try:
            for event in data:
                extracted_data = qwen_llm_handler.extract_data(text=event["data"], entities=", ".join(selected_entities))
                table_data = [{"Key": key, "Value": value} for key, value in extracted_data.items()]

                render_md(event["data"])
                render_table(table_data)

                count -= 1
                if count == 0:
                    break
        except Exception as e:
            st.write(f"Es ist ein Fehler aufgetreten: {e}")

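# Named entity recognition: GLiNER extracts entities for the user-provided labels.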
with st.expander("Named Entity Recognition"):
    with st.form("Settings NER"):
        count = st.number_input("Wie viele Veranstaltungen sollen getestet werden?", step=1)
        label_input = st.text_input("Gib die Labels der Entitäten getrennt durch Komma an.")
        submit_ner = st.form_submit_button("Start")

    if submit_ner:
        init_session_state(SessionStateConfig.GLINER_HANDLER, GlinerHandler())
        gliner_handler = st.session_state[SessionStateConfig.GLINER_HANDLER]
        labels = label_input.split(",") if label_input else []
        for event in data:
            text = normalize_data(event["data"])
            render_md(text)

            extracted_data = gliner_handler.extract_entities(text, labels)
            table_data = [{"Key": element["label"], "Value": element["text"] } for element in extracted_data]
            render_table(table_data)

            count -= 1
            if count == 0:
                break

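# Zero-shot text classification, either with the predefined category mode or with
# custom labels and a custom hypothesis template.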
with st.expander("Textclassification"):
    with st.form("Settings TextClassification"):
        mode = st.selectbox("Classification Mode", ["Categories", "Custom"])
        custom_labels = st.text_input("(Nur bei Custom Mode) Gib die Klassen-Labels ein, durch Komma getrennt.", placeholder="Theater,Oper,Film")
        custom_hypothesis_template = st.text_input("(Nur bei Custom Mode) Gib das Template ein. {} ist dabei der Platzhalter für die Labels.", placeholder="Die Art der Veranstaltung ist {}")
        count = st.number_input("Wie viele Veranstaltungen sollen getestet werden?", step=1)
        submit_textclass = st.form_submit_button("Start")

    if submit_textclass:
        init_session_state(SessionStateConfig.ZERO_SHOT_CLASSIFIER, ZeroShotClassifier())
        classifier = st.session_state[SessionStateConfig.ZERO_SHOT_CLASSIFIER]
        if mode == "Categories":
            classifier_mode = CategoryMode()
        elif custom_labels and custom_hypothesis_template:
            classifier_mode = CustomMode(labels=custom_labels.split(","), hypothesis_template=custom_hypothesis_template)
        else:
            # Fall back to the category mode if labels or template are missing in custom mode
            st.warning("Custom Mode benötigt Labels und Template, es wird stattdessen der Kategorien-Modus verwendet.")
            classifier_mode = CategoryMode()
        for event in data:
            text = normalize_data(event["data"])
            predictions = classifier.classify(text, classifier_mode)
            table_data = [{"Kategorie": p.label, "Score": p.score} for p in predictions]

            render_md(text)
            render_table(table_data)

            count -= 1
            if count == 0:
                break

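# Title extraction: compares the zero-shot extractor with the classy-classification
# (few-shot) variant, which requires a previously trained model.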
with st.expander("Titel Extraktion"):
    with st.form("Settings TitleExtraction"):
        count = st.number_input("Wie viele Veranstaltungen sollen getestet werden?", step=1)
        submit_title_extr = st.form_submit_button("Start")

    if submit_title_extr:
        init_session_state("title_extractor", TitleExtractor())
        title_extractor = st.session_state.title_extractor

        for event in data:
            text = normalize_data(event["data"])
            prediction = title_extractor.extract_title(text)
            try:
                pred2 = title_extractor.extract_title_classy_classification(text)
            except FileNotFoundError:
                pred2 = "ERROR: Train Model before usage"
            table_data = [{"Label": "Titel (ZeroShot)", "Value": prediction}, {"Label": "Titel (FewShot)", "Value": pred2}]

            render_md(text)
            render_table(table_data)

            count -= 1
            if count == 0:
                break

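# Extractive summarization: Sumy selects sentences from the plain text, and the
# markdown blocks containing those sentences are shown as the summary.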
with st.expander("Textsummarization"):
    with st.form("Settings Textsummarization"):
        count = st.number_input("Wie viele Veranstaltungen sollen getestet werden?", step=1)
        submit_textsummarization = st.form_submit_button("Start")

    if submit_textsummarization:
        init_session_state(SessionStateConfig.SUMY_SUMMARIZER, SumySummarizer())
        sumy_summarizer = st.session_state[SessionStateConfig.SUMY_SUMMARIZER]
        for event in data:
            try:
                md = normalize_data(event["data"])
                block_elements = MarkdownAnalyzer(md).identify_all()["block_elements"]
                block_elements = sorted(block_elements, key=lambda el: el.line)
                text = "\n\n".join([el.text for el in block_elements])
                sumy_summary = sumy_summarizer.summarize(text)
                summary = []
                for element in block_elements:
                    if any(sentence in element.markdown for sentence in sumy_summary):
                        summary.append(element.markdown)

                render_md(md)
                st.subheader("Extrahierte Daten:")
                with st.container(border=True, height=400):
                    st.markdown("\n\n".join(summary))

            except Exception as e:
                st.error(f"Fehler:{e}")
                logging.exception("message")
            count -= 1
            if count == 0:
                break