import gradio as gr
import torch
import torch.nn as nn
import re
from transformers import AutoTokenizer, T5EncoderModel
# Class definition carried over from training
class FlanT5Classifier(nn.Module):
    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoder_outputs.last_hidden_state[:, 0]  # first-token pooling
        logits = self.classifier(self.dropout(pooled))
        return {"logits": logits}
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")

# Instantiate the model and resize the token embeddings to match the tokenizer
# (which presumably includes the <TSTART>/<TEND> marker tokens)
model = FlanT5Classifier()
model.encoder.resize_token_embeddings(len(tokenizer))

# Load weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
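# The checkpoint is expected to sit next to this script. If it were hosted in
# the model repo instead (an assumption; the filename may differ), it could be
# fetched first, e.g.:
#   from huggingface_hub import hf_hub_download
#   weights_path = hf_hub_download(repo_id="pepegiallo/flan-t5-base_ner",
#                                  filename="pytorch_model.bin")
#   state_dict = torch.load(weights_path, map_location="cpu")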
# ID-to-label mapping
id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}

# Tokenizer helper functions
def custom_tokenize(text):
    return re.findall(r"\w+|[^\w\s]", text, re.UNICODE)

def custom_detokenize(tokens):
    text = ""
    for i, token in enumerate(tokens):
        # insert a space before word tokens, but not before punctuation
        if i > 0 and re.match(r"\w", token):
            text += " "
        text += token
    return text
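# Illustrative behaviour of the helpers above:
#   custom_tokenize("Angela Merkel visited Paris.")
#   -> ["Angela", "Merkel", "visited", "Paris", "."]
#   custom_detokenize(["Angela", "Merkel", "visited", "Paris", "."])
#   -> "Angela Merkel visited Paris."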
# Classification function
def classify_tokens(text):
    tokens = custom_tokenize(text)
    results = []
    for i in range(len(tokens)):
        # mark the i-th token with <TSTART>/<TEND> and classify it in context
        wrapped = tokens[:i] + ["<TSTART>", tokens[i], "<TEND>"] + tokens[i + 1:]
        prompt = "classify token in: " + custom_detokenize(wrapped)
        inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs)["logits"]
        pred_id = torch.argmax(logits, dim=-1).item()
        label = id2label[pred_id]
        results.append((tokens[i], label))
    return results
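# For the input "Angela Merkel visited Paris." and i = 0, the prompt becomes
#   "classify token in: <TSTART> Angela<TEND> Merkel visited Paris."
# (no space before <TEND>, since custom_detokenize only pads word tokens).
# classify_tokens returns a list of (token, label) pairs, e.g.
# [("Angela", <label>), ("Merkel", <label>), ...], which is exactly the format
# gr.HighlightedText accepts. Note that this runs one forward pass per token.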
# Gradio UI
demo = gr.Interface(
    fn=classify_tokens,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
    outputs=gr.HighlightedText(label="Token Classification Output"),
    title="Flan-T5 Token Classification (NER)",
    description="Classifies each token in the input text as LOC, ORG, PER, or O.",
)

demo.launch()
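# Running this script locally serves the UI at http://127.0.0.1:7860 by default
# (Gradio's standard port); on a Hugging Face Space, demo.launch() is all that
# is needed.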