ApsidalSolid4 commited on
Commit
f99111f
·
verified ·
1 Parent(s): 9a92a21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -69,26 +69,29 @@ class TextClassifier:
69
  self.initialize_model()
70
 
71
  def initialize_model(self):
72
- """Initialize the model and tokenizer."""
73
- logger.info("Initializing model and tokenizer...")
74
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
75
-
76
- # First initialize the base model
77
- self.model = AutoModelForSequenceClassification.from_pretrained(
78
- self.model_name,
79
- num_labels=2
80
- ).to(self.device)
81
-
82
- # Look for model file in the same directory as the code
83
- model_path = "model.pt" # Your model file should be uploaded as model.pt
84
- if os.path.exists(model_path):
85
- logger.info(f"Loading custom model from {model_path}")
86
- checkpoint = torch.load(model_path, map_location=self.device)
87
- self.model.load_state_dict(checkpoint['model_state_dict'])
88
- else:
89
- logger.warning("Custom model file not found. Using base model.")
90
 
91
- self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def predict_with_sentence_scores(self, text: str) -> Dict:
94
  """Predict with sentence-level granularity using overlapping windows."""
 
69
  self.initialize_model()
70
 
71
  def initialize_model(self):
72
+ """Initialize the model and tokenizer."""
73
+ logger.info("Initializing model and tokenizer...")
74
+ self.tokenizer = AutoTokenizer.from_pretrained(
75
+ self.model_name,
76
+ use_fast=False # Use the slow tokenizer instead of the fast one
77
+ )
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # First initialize the base model
80
+ self.model = AutoModelForSequenceClassification.from_pretrained(
81
+ self.model_name,
82
+ num_labels=2
83
+ ).to(self.device)
84
+
85
+ # Look for model file in the same directory as the code
86
+ model_path = "model.pt" # Your model file should be uploaded as model.pt
87
+ if os.path.exists(model_path):
88
+ logger.info(f"Loading custom model from {model_path}")
89
+ checkpoint = torch.load(model_path, map_location=self.device)
90
+ self.model.load_state_dict(checkpoint['model_state_dict'])
91
+ else:
92
+ logger.warning("Custom model file not found. Using base model.")
93
+
94
+ self.model.eval()
95
 
96
  def predict_with_sentence_scores(self, text: str) -> Dict:
97
  """Predict with sentence-level granularity using overlapping windows."""