pepegiallo committed on
Commit
dbe21a2
·
verified ·
1 Parent(s): cc5f6e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -6
app.py CHANGED
@@ -1,16 +1,39 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import re
 
 
5
 
6
- # Load the model
7
- model = AutoModelForSequenceClassification.from_pretrained("pepegiallo/flan-t5-base_ner")
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")
 
 
 
 
 
 
 
 
9
  model.eval()
10
 
 
11
  id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}
12
 
13
- # Helper functions
14
  def custom_tokenize(text):
15
  return re.findall(r"\w+|[^\w\s]", text, re.UNICODE)
16
 
@@ -22,6 +45,7 @@ def custom_detokenize(tokens):
22
  text += token
23
  return text
24
 
 
25
  def classify_tokens(text):
26
  tokens = custom_tokenize(text)
27
  results = []
@@ -32,14 +56,14 @@ def classify_tokens(text):
32
 
33
  inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
34
  with torch.no_grad():
35
- logits = model(**inputs).logits
36
  pred_id = torch.argmax(logits, dim=-1).item()
37
  label = id2label[pred_id]
38
 
39
  results.append((tokens[i], label))
40
  return results
41
 
42
- # Define the Gradio UI
43
  demo = gr.Interface(
44
  fn=classify_tokens,
45
  inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
 
1
  import gradio as gr
2
  import torch
 
3
  import re
4
+ from transformers import AutoTokenizer, T5EncoderModel
5
+ import torch.nn as nn
6
 
7
+ # Class definition from training
8
class FlanT5Classifier(nn.Module):
    """Classification head on top of a T5 encoder.

    The module wraps ``T5EncoderModel`` loaded from ``base_model_name``,
    applies dropout (p=0.1) to the encoder output taken at sequence
    position 0, and projects it with a linear layer to ``num_labels``
    scores.

    Args:
        base_model_name: Checkpoint name passed to
            ``T5EncoderModel.from_pretrained``.
        num_labels: Output size of the linear classification layer.
    """

    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        super().__init__()
        # Attribute names are part of the state-dict keys; keep them stable
        # so previously saved weights still load.
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.encoder.config.d_model
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return ``{"logits": ...}`` computed from the first sequence position.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            A dict with a single key ``"logits"`` holding the classifier
            output.
        """
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # presumably (batch, seq, d_model); position 0 is used as the pooled
        # representation — TODO confirm against the training code
        first_position = hidden_states[:, 0]
        return {"logits": self.classifier(self.dropout(first_position))}
20
+
21
+ # Load the tokenizer
22
  tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")
23
+
24
+ # Instantiate the model and resize the token embeddings
25
+ model = FlanT5Classifier()
26
+ model.encoder.resize_token_embeddings(len(tokenizer))
27
+
28
+ # Load the weights
29
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
30
+ model.load_state_dict(state_dict)
31
  model.eval()
32
 
33
+ # ID-to-label mapping
34
  id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}
35
 
36
+ # Tokenizer functions
37
def custom_tokenize(text):
    """Split *text* into word tokens and single punctuation tokens.

    A token is either a maximal run of word characters (letters, digits,
    underscore) or one character that is neither a word character nor
    whitespace. Whitespace is discarded.

    Args:
        text: The string to tokenize.

    Returns:
        A list of token strings in order of appearance; empty for an empty
        or whitespace-only input.
    """
    # str patterns are Unicode-aware by default in Python 3, so the
    # re.UNICODE flag is redundant and omitted.
    return re.findall(r"\w+|[^\w\s]", text)
39
 
 
45
  text += token
46
  return text
47
 
48
+ # Classification function
49
  def classify_tokens(text):
50
  tokens = custom_tokenize(text)
51
  results = []
 
56
 
57
  inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
58
  with torch.no_grad():
59
+ logits = model(**inputs)["logits"]
60
  pred_id = torch.argmax(logits, dim=-1).item()
61
  label = id2label[pred_id]
62
 
63
  results.append((tokens[i], label))
64
  return results
65
 
66
+ # Gradio UI
67
  demo = gr.Interface(
68
  fn=classify_tokens,
69
  inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),