zhenyundeng commited on
Commit
287a9db
·
1 Parent(s): 17775ce

update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -85,6 +85,9 @@ justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', a
85
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
86
  best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
87
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
 
 
 
88
  # ---------------------------------------------------------------------------
89
 
90
  # ----------------------------------------------------------------------------
@@ -280,6 +283,7 @@ def veracity_prediction(claim, evidence):
280
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
281
  example_support = torch.argmax(
282
  veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
 
283
 
284
  has_unanswerable = False
285
  has_true = False
@@ -342,6 +346,7 @@ def justification_generation(claim, evidence, verdict_label):
342
  claim_str = extract_claim_str(claim, evidence, verdict_label)
343
  claim_str.strip()
344
  pred_justification = justification_model.generate(claim_str, device=device)
 
345
 
346
  return pred_justification.strip()
347
 
@@ -363,8 +368,8 @@ def log_on_azure(file, logs, azure_share_client):
363
  file_client.upload_file(logs)
364
 
365
 
 
366
  @app.post("/predict/")
367
- @spaces.GPU
368
  def fact_checking(item: Item):
369
  claim = item['claim']
370
  source = item['source']
 
85
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
86
  best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
87
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
88
+
89
+ print("veracity_model_device_0:{}".format(veracity_model.device))
90
+ print("justification_model_device_0:{}".format(justification_model.device))
91
  # ---------------------------------------------------------------------------
92
 
93
  # ----------------------------------------------------------------------------
 
283
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
284
  example_support = torch.argmax(
285
  veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
286
+ print("veracity_model_device_1:{}".format(veracity_model.device))
287
 
288
  has_unanswerable = False
289
  has_true = False
 
346
  claim_str = extract_claim_str(claim, evidence, verdict_label)
347
  claim_str.strip()
348
  pred_justification = justification_model.generate(claim_str, device=device)
349
+ print("justification_model_device_1:{}".format(justification_model.device))
350
 
351
  return pred_justification.strip()
352
 
 
368
  file_client.upload_file(logs)
369
 
370
 
371
+ # @spaces.GPU
372
  @app.post("/predict/")
 
373
  def fact_checking(item: Item):
374
  claim = item['claim']
375
  source = item['source']