JanviMl committed on
Commit
c3a2cbd
·
verified ·
1 Parent(s): 9302051

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +66 -32
model_loader.py CHANGED
@@ -1,33 +1,67 @@
1
  # model_loader.py
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
- from transformers import AutoModelForCausalLM
4
-
5
- def load_classifier_model_and_tokenizer():
6
- """
7
- Load the fine-tuned XLM-RoBERTa model and tokenizer for toxic comment classification.
8
- Returns the model and tokenizer.
9
- """
10
- try:
11
- model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
12
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
14
- return model, tokenizer
15
- except Exception as e:
16
- raise Exception(f"Error loading classifier model or tokenizer: {str(e)}")
17
-
18
- def load_paraphrase_model_and_tokenizer():
19
- """
20
- Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.
21
- Returns the model and tokenizer.
22
- """
23
- try:
24
- model_name = "ibm-granite/granite-3.2-2b-instruct"
25
- model = AutoModelForCausalLM.from_pretrained(model_name)
26
- tokenizer = AutoTokenizer.from_pretrained(model_name)
27
- return model, tokenizer
28
- except Exception as e:
29
- raise Exception(f"Error loading paraphrase model or tokenizer: {str(e)}")
30
-
31
- # Load both models and tokenizers at startup
32
- classifier_model, classifier_tokenizer = load_classifier_model_and_tokenizer()
33
- paraphrase_model, paraphrase_tokenizer = load_paraphrase_model_and_tokenizer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # model_loader.py
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
3
+ from sentence_transformers import SentenceTransformer
4
+ from transformers import pipeline
5
+
6
# Classifier Model (XLM-RoBERTa for toxicity classification)
class ClassifierModel:
    """Holder for the toxic-comment classifier and its tokenizer.

    Both are loaded eagerly: constructing an instance calls ``load_model``.

    Attributes:
        model: the sequence-classification model, or None if not loaded.
        tokenizer: the matching tokenizer, or None if not loaded.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """
        Load the fine-tuned XLM-RoBERTa model and tokenizer for toxic comment classification.

        Raises:
            Exception: if either the model or the tokenizer fails to load.
                The underlying error is attached as ``__cause__``.
        """
        model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"
        try:
            # Load into locals first so a tokenizer failure cannot leave the
            # instance with a model but no tokenizer.
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        except Exception as e:
            raise Exception(f"Error loading classifier model or tokenizer: {str(e)}") from e
        self.model = model
        self.tokenizer = tokenizer
23
+
24
# Paraphraser Model (Granite 3.2-2B-Instruct for paraphrasing)
class ParaphraserModel:
    """Holder for the paraphrasing language model and its tokenizer.

    Both are loaded eagerly: constructing an instance calls ``load_model``.

    Attributes:
        model: the causal language model, or None if not loaded.
        tokenizer: the matching tokenizer, or None if not loaded.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """
        Load the Granite 3.2-2B-Instruct model and tokenizer for paraphrasing.

        Raises:
            Exception: if either the model or the tokenizer fails to load.
                The underlying error is attached as ``__cause__``.
        """
        model_name = "ibm-granite/granite-3.2-2b-instruct"
        try:
            # Load into locals first so a tokenizer failure cannot leave the
            # instance with a model but no tokenizer.
            model = AutoModelForCausalLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            raise Exception(f"Error loading paraphrase model or tokenizer: {str(e)}") from e
        self.model = model
        self.tokenizer = tokenizer
41
+
42
# Metrics Models (Sentence-BERT, Emotion Classifier, NLI)
class MetricsModels:
    """Lazily loaded models used for metrics.

    Nothing is loaded at construction. Each ``load_*`` method builds its
    model on the first call, stores it on the instance, and returns the
    stored object on every later call.
    """

    def __init__(self):
        self.sentence_bert_model = None
        self.emotion_classifier = None
        self.nli_classifier = None

    def load_sentence_bert(self):
        """Return the SentenceTransformer model, creating it on first use."""
        if self.sentence_bert_model is None:
            self.sentence_bert_model = SentenceTransformer('all-MiniLM-L6-v2')
        return self.sentence_bert_model

    def load_emotion_classifier(self):
        """Return the emotion text-classification pipeline, creating it on first use."""
        if self.emotion_classifier is None:
            self.emotion_classifier = pipeline(
                "text-classification",
                model="bhadresh-savani/distilbert-base-uncased-emotion",
                top_k=None,
            )
        return self.emotion_classifier

    def load_nli_classifier(self):
        """Return the zero-shot-classification pipeline, creating it on first use."""
        if self.nli_classifier is None:
            self.nli_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
            )
        return self.nli_classifier
63
+
64
# Singleton instances, created once when this module is imported.
# ClassifierModel and ParaphraserModel call load_model() in __init__, so
# importing this module loads both of those models; MetricsModels only sets
# its attributes to None here and loads each model on first request.
classifier_model = ClassifierModel()
paraphraser_model = ParaphraserModel()
metrics_models = MetricsModels()