Spaces:
Sleeping
Sleeping
zhenyundeng
committed on
Commit
·
287a9db
1
Parent(s):
17775ce
update app.py
Browse files
app.py
CHANGED
@@ -85,6 +85,9 @@ justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', a
|
|
85 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
86 |
best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
87 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
|
|
|
|
|
|
88 |
# ---------------------------------------------------------------------------
|
89 |
|
90 |
# ----------------------------------------------------------------------------
|
@@ -280,6 +283,7 @@ def veracity_prediction(claim, evidence):
|
|
280 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
281 |
example_support = torch.argmax(
|
282 |
veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
|
|
|
283 |
|
284 |
has_unanswerable = False
|
285 |
has_true = False
|
@@ -342,6 +346,7 @@ def justification_generation(claim, evidence, verdict_label):
|
|
342 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
343 |
claim_str.strip()
|
344 |
pred_justification = justification_model.generate(claim_str, device=device)
|
|
|
345 |
|
346 |
return pred_justification.strip()
|
347 |
|
@@ -363,8 +368,8 @@ def log_on_azure(file, logs, azure_share_client):
|
|
363 |
file_client.upload_file(logs)
|
364 |
|
365 |
|
|
|
366 |
@app.post("/predict/")
|
367 |
-
@spaces.GPU
|
368 |
def fact_checking(item: Item):
|
369 |
claim = item['claim']
|
370 |
source = item['source']
|
|
|
85 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
86 |
best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
87 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
88 |
+
|
89 |
+
print("veracity_model_device_0:{}".format(veracity_model.device))
|
90 |
+
print("justification_model_device_0:{}".format(justification_model.device))
|
91 |
# ---------------------------------------------------------------------------
|
92 |
|
93 |
# ----------------------------------------------------------------------------
|
|
|
283 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
284 |
example_support = torch.argmax(
|
285 |
veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
|
286 |
+
print("veracity_model_device_1:{}".format(veracity_model.device))
|
287 |
|
288 |
has_unanswerable = False
|
289 |
has_true = False
|
|
|
346 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
347 |
claim_str.strip()
|
348 |
pred_justification = justification_model.generate(claim_str, device=device)
|
349 |
+
print("justification_model_device_1:{}".format(justification_model.device))
|
350 |
|
351 |
return pred_justification.strip()
|
352 |
|
|
|
368 |
file_client.upload_file(logs)
|
369 |
|
370 |
|
371 |
+
# @spaces.GPU
|
372 |
@app.post("/predict/")
|
|
|
373 |
def fact_checking(item: Item):
|
374 |
claim = item['claim']
|
375 |
source = item['source']
|