JanviMl commited on
Commit
af98023
·
verified ·
1 Parent(s): 5289046

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +24 -10
model_loader.py CHANGED
@@ -1,19 +1,33 @@
 
1
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
2
 
3
- def load_model_and_tokenizer():
4
  """
5
- Load the fine-tuned XLM-RoBERTa model and tokenizer.
6
- Returns the model and tokenizer for use in classification.
7
  """
8
  try:
9
- model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone" # Replace with your model repo ID
10
- # If the model is local: model_name = "./model"
11
-
12
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # Use slow tokenizer
14
  return model, tokenizer
15
  except Exception as e:
16
- raise Exception(f"Error loading model or tokenizer: {str(e)}")
17
 
18
- # Load the model and tokenizer once at startup
19
- model, tokenizer = load_model_and_tokenizer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_loader.py
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from transformers import AutoModelForCausalLM
4
 
5
+ def load_classifier_model_and_tokenizer():
6
  """
7
+ Load the fine-tuned XLM-RoBERTa model and tokenizer for toxic comment classification.
8
+ Returns the model and tokenizer.
9
  """
10
  try:
11
+ model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
 
 
12
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
14
  return model, tokenizer
15
  except Exception as e:
16
+ raise Exception(f"Error loading classifier model or tokenizer: {str(e)}")
17
 
18
+ def load_paraphrase_model_and_tokenizer():
19
+ """
20
+ Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.
21
+ Returns the model and tokenizer.
22
+ """
23
+ try:
24
+ model_name = "ibm-granite/granite-3.2-2b-instruct"
25
+ model = AutoModelForCausalLM.from_pretrained(model_name)
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+ return model, tokenizer
28
+ except Exception as e:
29
+ raise Exception(f"Error loading paraphrase model or tokenizer: {str(e)}")
30
+
31
+ # Load both models and tokenizers at startup
32
+ classifier_model, classifier_tokenizer = load_classifier_model_and_tokenizer()
33
+ paraphrase_model, paraphrase_tokenizer = load_paraphrase_model_and_tokenizer()