Yassine commited on
Commit
7b036e8
·
1 Parent(s): 2be4d6b

Fix the allowed entities for each type

Browse files
Files changed (1) hide show
  1. main.py +12 -2
main.py CHANGED
@@ -33,6 +33,12 @@ base_dir = Path(__file__).parent.absolute()
33
  # Your Hugging Face Hub username
34
  HF_USERNAME = "YassineJedidi" # Replace with your actual username
35
 
 
 
 
 
 
 
36
  # Try to load models from Hugging Face Hub
37
  try:
38
  print("Loading models from Hugging Face Hub")
@@ -183,10 +189,14 @@ async def analyze_text(input_data: TextInput):
183
  type_result = await predict_type(input_data)
184
  text_type = type_result["type"]
185
  confidence = type_result["confidence"]
186
- entities = (await extract_entities(input_data))["entities"]
 
 
 
 
187
 
188
  return {
189
  "type": text_type,
190
  "confidence": confidence,
191
- "entities": entities
192
  }
 
33
  # Your Hugging Face Hub username
34
  HF_USERNAME = "YassineJedidi" # Replace with your actual username
35
 
36
+ # Définition des entités valides pour chaque type
37
+ entites_valides = {
38
+ "Tâche": {"TITRE", "DELAI", "PRIORITE"},
39
+ "Événement": {"TITRE", "DATE_HEURE"},
40
+ }
41
+
42
  # Try to load models from Hugging Face Hub
43
  try:
44
  print("Loading models from Hugging Face Hub")
 
189
  type_result = await predict_type(input_data)
190
  text_type = type_result["type"]
191
  confidence = type_result["confidence"]
192
+ raw_entities = (await extract_entities(input_data))["entities"]
193
+
194
+ # Filtrage des entités selon le type détecté
195
+ allowed = entites_valides.get(text_type, set())
196
+ filtered_entities = {k: v for k, v in raw_entities.items() if k in allowed}
197
 
198
  return {
199
  "type": text_type,
200
  "confidence": confidence,
201
+ "entities": filtered_entities
202
  }