"""Gradio demo: token-level NER with a fine-tuned Flan-T5 encoder classifier.

Every token of the input text is classified as LOC, ORG, PER or O. For each
token a separate prompt is built (the target token wrapped in marker strings)
and passed through a T5 encoder with a linear classification head.
"""

import re

import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5EncoderModel


# Class definition taken from the training code. The checkpoint loaded below
# must match this module layout exactly (load_state_dict is strict).
class FlanT5Classifier(nn.Module):
    """T5 encoder followed by dropout and a linear classification head."""

    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return ``{"logits": logits}`` with one row of label scores per sequence.

        The sequence is pooled by taking the encoder hidden state at position 0.
        T5 has no dedicated CLS token; this mirrors the training code.
        """
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoder_outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(pooled))
        return {"logits": logits}


# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")

# Instantiate the model and resize the token embeddings to the tokenizer's
# vocabulary size, so the embedding matrix shape matches the saved weights.
model = FlanT5Classifier()
model.encoder.resize_token_embeddings(len(tokenizer))

# Load the fine-tuned weights. weights_only=True limits unpickling to tensors
# and plain containers, so a tampered checkpoint file cannot execute code.
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Mapping from classifier output index to label.
id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}

# Maximum prompt length in T5 tokens. Longer prompts are truncated, which can
# drop the end of long inputs (and with it the marked token). Presumably this
# matches the training setup -- TODO confirm.
MAX_PROMPT_LENGTH = 128

# NOTE(review): these markers, which wrap the token being classified, were
# empty strings in the original code. With empty markers custom_detokenize
# produces (effectively) the same prompt for every position, so every token
# receives the same prediction. They look like tags that were lost in a
# copy/paste. Set them to the exact marker strings used during training
# before trusting the output.
TOKEN_START_MARKER = ""
TOKEN_END_MARKER = ""

# A run of word characters, or a single non-word, non-space character.
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
_WORD_START = re.compile(r"\w")


def custom_tokenize(text):
    """Split *text* into word runs and single punctuation characters.

    Whitespace is discarded.
    """
    return _TOKEN_PATTERN.findall(text)


def custom_detokenize(tokens):
    """Join *tokens*, adding a space only before tokens that start with a word char.

    Punctuation is therefore attached to the preceding token; for example
    ["Hi", ",", "there"] becomes "Hi, there". This must stay identical to the
    formatting used when the training prompts were built.
    """
    parts = []
    for i, token in enumerate(tokens):
        if i > 0 and _WORD_START.match(token):
            parts.append(" ")
        parts.append(token)
    return "".join(parts)


def _build_prompt(tokens, i):
    """Return the classification prompt for tokens[i], wrapped in the markers."""
    wrapped = tokens[:i] + [TOKEN_START_MARKER, tokens[i], TOKEN_END_MARKER] + tokens[i + 1:]
    return "classify token in: " + custom_detokenize(wrapped)


def classify_tokens(text, batch_size=32):
    """Classify every token of *text*.

    Args:
        text: Raw input text.
        batch_size: Number of per-token prompts sent through the model in one
            forward pass.

    Returns:
        A list of ``(token, label)`` tuples in input order, one per token from
        ``custom_tokenize``. Returns an empty list for input without tokens.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    tokens = custom_tokenize(text)
    if not tokens:
        # Nothing to classify; also avoids calling the tokenizer on an empty batch.
        return []

    prompts = [_build_prompt(tokens, i) for i in range(len(tokens))]
    labels = []
    with torch.no_grad():
        for start in range(0, len(prompts), batch_size):
            # Dynamic padding is enough: the attention mask excludes pad
            # positions, so the pooled position-0 state is unaffected.
            inputs = tokenizer(
                prompts[start:start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_PROMPT_LENGTH,
            )
            logits = model(**inputs)["logits"]
            pred_ids = torch.argmax(logits, dim=-1).tolist()
            labels.extend(id2label[pred_id] for pred_id in pred_ids)
    return list(zip(tokens, labels))


def highlight_tokens(text):
    """Return ``classify_tokens(text)`` with the original whitespace restored.

    gr.HighlightedText concatenates the segment strings verbatim. Without the
    whitespace gaps, the rendered output would run words together. Gaps are
    emitted as ``(whitespace, None)`` so they are not highlighted.
    """
    segments = []
    cursor = 0
    # _TOKEN_PATTERN matches every non-whitespace character, so the gaps
    # between matches are pure whitespace, and the matches line up one-to-one
    # with custom_tokenize's output.
    for match, (token, label) in zip(_TOKEN_PATTERN.finditer(text), classify_tokens(text)):
        if match.start() > cursor:
            segments.append((text[cursor:match.start()], None))
        segments.append((token, label))
        cursor = match.end()
    if cursor < len(text):
        segments.append((text[cursor:], None))
    return segments


# Gradio UI
demo = gr.Interface(
    fn=highlight_tokens,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
    outputs=gr.HighlightedText(label="Token Classification Output"),
    title="Flan-T5 Token Classification (NER)",
    description="Classifies each token in the input text as LOC, ORG, PER, or O.",
)
demo.launch()