zekun-li commited on
Commit
465e931
·
verified ·
1 Parent(s): 117e039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -6
app.py CHANGED
@@ -10,6 +10,29 @@ model.to("cpu") # Use "cuda" if you have GPU
10
  model.eval()
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def get_toponym_entities(text):
14
  inputs = tokenizer(
15
  text,
@@ -17,20 +40,39 @@ def get_toponym_entities(text):
17
  return_tensors="pt",
18
  truncation=True,
19
  max_length=512,
 
20
  )
21
  offset_mapping = inputs.pop("offset_mapping")[0]
22
- input_ids = inputs["input_ids"]
23
 
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  predictions = torch.argmax(outputs.logits, dim=2)[0]
27
 
28
  entities = []
29
- for idx, label_id in enumerate(predictions):
30
- if label_id != 0 and idx < len(offset_mapping):
31
- start, end = offset_mapping[idx].tolist()
32
- if end > start:
33
- entities.append({"start": start, "end": end, "entity": "Topo"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  return {"text": text, "entities": entities}
36
 
 
10
  model.eval()
11
 
12
 
13
+ # def get_toponym_entities(text):
14
+ # inputs = tokenizer(
15
+ # text,
16
+ # return_offsets_mapping=True,
17
+ # return_tensors="pt",
18
+ # truncation=True,
19
+ # max_length=512,
20
+ # )
21
+ # offset_mapping = inputs.pop("offset_mapping")[0]
22
+ # input_ids = inputs["input_ids"]
23
+
24
+ # with torch.no_grad():
25
+ # outputs = model(**inputs)
26
+ # predictions = torch.argmax(outputs.logits, dim=2)[0]
27
+
28
+ # entities = []
29
+ # for idx, label_id in enumerate(predictions):
30
+ # if label_id != 0 and idx < len(offset_mapping):
31
+ # start, end = offset_mapping[idx].tolist()
32
+ # if end > start:
33
+ # entities.append({"start": start, "end": end, "entity": "Topo"})
34
+
35
+ # return {"text": text, "entities": entities}
36
  def get_toponym_entities(text):
37
  inputs = tokenizer(
38
  text,
 
40
  return_tensors="pt",
41
  truncation=True,
42
  max_length=512,
43
+ return_attention_mask=True,
44
  )
45
  offset_mapping = inputs.pop("offset_mapping")[0]
46
+ input_ids = inputs["input_ids"][0]
47
 
48
  with torch.no_grad():
49
  outputs = model(**inputs)
50
  predictions = torch.argmax(outputs.logits, dim=2)[0]
51
 
52
  entities = []
53
+ current_entity = None
54
+
55
+ for idx, (pred, offset) in enumerate(zip(predictions, offset_mapping)):
56
+ start, end = offset.tolist()
57
+ if start == end: # skip special tokens
58
+ continue
59
+
60
+ if pred != 0: # Non-O label
61
+ if current_entity is None:
62
+ current_entity = {"start": start, "end": end}
63
+ else:
64
+ # Extend the current entity span
65
+ current_entity["end"] = end
66
+ else:
67
+ if current_entity is not None:
68
+ current_entity["entity"] = "Topo"
69
+ entities.append(current_entity)
70
+ current_entity = None
71
+
72
+ # Catch any lingering entity at the end
73
+ if current_entity is not None:
74
+ current_entity["entity"] = "Topo"
75
+ entities.append(current_entity)
76
 
77
  return {"text": text, "entities": entities}
78