pepegiallo commited on
Commit
59e0ff2
Β·
verified Β·
1 Parent(s): e089b92

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import re
5
+
6
+ # Modell laden
7
+ model = AutoModelForSequenceClassification.from_pretrained("pepegiallo/flan-t5-base_ner")
8
+ tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")
9
+ model.eval()
10
+
11
+ id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}
12
+
13
+ # Hilfsfunktionen
14
+ def custom_tokenize(text):
15
+ return re.findall(r"\w+|[^\w\s]", text, re.UNICODE)
16
+
17
+ def custom_detokenize(tokens):
18
+ text = ""
19
+ for i, token in enumerate(tokens):
20
+ if i > 0 and re.match(r"\w", token):
21
+ text += " "
22
+ text += token
23
+ return text
24
+
25
+ def classify_tokens(text):
26
+ tokens = custom_tokenize(text)
27
+ results = []
28
+
29
+ for i in range(len(tokens)):
30
+ wrapped = tokens[:i] + ["<TSTART>", tokens[i], "<TEND>"] + tokens[i+1:]
31
+ prompt = "classify token in: " + custom_detokenize(wrapped)
32
+
33
+ inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
34
+ with torch.no_grad():
35
+ logits = model(**inputs).logits
36
+ pred_id = torch.argmax(logits, dim=-1).item()
37
+ label = id2label[pred_id]
38
+
39
+ results.append((tokens[i], label))
40
+ return results
41
+
42
+ # Gradio-UI definieren
43
+ demo = gr.Interface(
44
+ fn=classify_tokens,
45
+ inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
46
+ outputs=gr.HighlightedText(label="Token Classification Output"),
47
+ title="Flan-T5 Token Classification (NER)",
48
+ description="Classifies each token in the input text as LOC, ORG, PER, or O."
49
+ )
50
+
51
+ demo.launch()