pepegiallo's picture
Update app.py
dbe21a2 verified
raw
history blame contribute delete
2.48 kB
import gradio as gr
import torch
import re
from transformers import AutoTokenizer, T5EncoderModel
import torch.nn as nn
# Class definition from training
class FlanT5Classifier(nn.Module):
    """T5 encoder followed by dropout and a linear classification head.

    The encoder is loaded from ``base_model_name``; the head maps the
    encoder's ``d_model``-sized hidden state to ``num_labels`` logits.
    """

    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        """Build the encoder, the dropout layer and the linear head."""
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.encoder.config.d_model
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return ``{"logits": ...}`` computed from the first sequence position.

        The encoder's last hidden state is sliced at sequence index 0,
        passed through dropout, and projected by the linear head.
        """
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Use the hidden state at the first position as the pooled summary.
        first_position = hidden_states[:, 0]
        return {"logits": self.classifier(self.dropout(first_position))}
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")
# Instantiate the model and resize the encoder's token embeddings to the
# tokenizer's vocabulary size. Presumably required because extra tokens
# (e.g. the <TSTART>/<TEND> markers used by classify_tokens) were added to
# the tokenizer for training -- TODO confirm.
model = FlanT5Classifier()
model.encoder.resize_token_embeddings(len(tokenizer))
# Load the weights onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source, and consider weights_only=True if the installed torch
# version supports it.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
# Switch to evaluation mode so the dropout layer is inactive at inference.
model.eval()
# Mapping from predicted class index to label string
id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}
# Tokenizer functions
def custom_tokenize(text):
    """Split ``text`` into word tokens and single punctuation tokens.

    A token is either a run of word characters or one character that is
    neither a word character nor whitespace. Whitespace is discarded.
    """
    token_pattern = re.compile(r"\w+|[^\w\s]", re.UNICODE)
    return token_pattern.findall(text)
def custom_detokenize(tokens):
    """Join ``tokens`` back into a single string.

    A space is inserted before every token except the first whose first
    character is a word character; all other tokens are appended directly
    to the preceding text.
    """
    return "".join(
        (" " + token) if index > 0 and re.match(r"\w", token) else token
        for index, token in enumerate(tokens)
    )
# Classification function
def classify_tokens(text):
    """Classify every token of ``text`` and return ``(token, label)`` pairs.

    The text is split with ``custom_tokenize``. For each token, a prompt is
    built in which that token is wrapped in ``<TSTART>``/``<TEND>`` markers,
    the prompt is run through the model, and the highest-scoring class index
    is mapped to its label through ``id2label``. One forward pass is made
    per token.

    Prompts are padded/truncated to 128 tokenizer tokens, so for long
    inputs the marked token may fall outside the truncated prompt.
    """
    tokens = custom_tokenize(text)
    labelled = []
    for position, token in enumerate(tokens):
        # Surround only the current token with the start/end markers.
        marked = [
            *tokens[:position],
            "<TSTART>",
            token,
            "<TEND>",
            *tokens[position + 1:],
        ]
        prompt = "classify token in: " + custom_detokenize(marked)
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**encoded)["logits"]
        predicted_id = torch.argmax(logits, dim=-1).item()
        labelled.append((token, id2label[predicted_id]))
    return labelled
# Gradio UI
# Input: a three-line free-text box; output: per-token highlighted labels.
_sentence_input = gr.Textbox(lines=3, placeholder="Enter a sentence...")
_labelled_output = gr.HighlightedText(label="Token Classification Output")

# Wire classify_tokens into a simple text-in / highlighted-text-out interface.
demo = gr.Interface(
    fn=classify_tokens,
    inputs=_sentence_input,
    outputs=_labelled_output,
    title="Flan-T5 Token Classification (NER)",
    description="Classifies each token in the input text as LOC, ORG, PER, or O.",
)

# Start the Gradio server.
demo.launch()