# NOTE: "Spaces: Sleeping" -- hosting-page status text captured along with this file; not part of the script.
import argparse | |
from datasets import load_dataset | |
from inference import SentimentInference | |
def run_sample_inference(config_path: str = "config.yaml", num_samples: int = 5):
    """Run the sentiment model on randomly chosen IMDB test reviews and print results.

    Builds a ``SentimentInference`` from ``config_path``, loads the ``test``
    split of the ``imdb`` dataset, shuffles it, and for each selected review
    prints a text preview, the true label, the predicted sentiment and the
    prediction confidence.

    Args:
        config_path: Path passed to ``SentimentInference``; the configuration
            file is expected to specify the model path.
        num_samples: Number of reviews to run inference on. Requests larger
            than the split are clamped to its size; zero or negative values
            result in no inference being run.

    Returns:
        None. All output is printed to stdout. Returns early (after printing
        a diagnostic) if the dataset cannot be loaded or there is nothing to
        sample.
    """
    preview_chars = 200  # Maximum number of review characters echoed per sample.

    print("Loading sentiment model...")
    # Ensure config_path points to a configuration file that specifies the model path.
    inferer = SentimentInference(config_path=config_path)
    print("Model loaded.")

    print("\nLoading IMDB dataset (test split for validation samples)...")
    # IMDB ships no dedicated validation split here, so the test split is used.
    try:
        imdb_dataset = load_dataset("imdb", split="test")
    except Exception as e:
        # Deliberately broad: this is a script-level boundary and any load
        # failure (network, cache, missing package data) is reported the same way.
        print(f"Failed to load IMDB dataset: {e}")
        print("Please ensure you have an internet connection and the `datasets` library can access Hugging Face.")
        print("You might need to run `pip install datasets` or check your network settings.")
        return

    # Clamp the request to what the split can provide; selecting indices past
    # the end of the dataset would otherwise raise.
    available = len(imdb_dataset)
    sample_count = max(0, min(num_samples, available))
    if sample_count != num_samples:
        print(f"Requested {num_samples} samples; using {sample_count} (dataset has {available}).")
    if sample_count == 0:
        print("No samples to run inference on.")
        return

    print(f"Taking {sample_count} samples from the dataset.")
    # Shuffle first so the selection is a random subset rather than the head of the split.
    samples = imdb_dataset.shuffle().select(range(sample_count))

    print("\nRunning inference on selected samples:\n")
    for position, sample in enumerate(samples, start=1):
        text = sample["text"]
        # Label id 1 is reported as positive; any other id as negative.
        true_label = "positive" if sample["label"] == 1 else "negative"

        # Only mark the preview as truncated when text was actually cut off.
        preview = text[:preview_chars]
        if len(text) > preview_chars:
            preview += "..."

        print(f"--- Sample {position}/{sample_count} ---")
        print(f"Text: {preview}")
        print(f"True Sentiment: {true_label}")

        prediction = inferer.predict(text)
        print(f"Predicted Sentiment: {prediction['sentiment']}")
        print(f"Confidence: {prediction['confidence']:.4f}\n")
if __name__ == "__main__":
    # Command-line entry point: both options are optional and map one-to-one
    # onto the keyword arguments of run_sample_inference.
    cli = argparse.ArgumentParser(description="Run sample inference on IMDB dataset.")
    cli.add_argument(
        "--config_path", type=str, default="config.yaml",
        help="Path to the configuration file (e.g., config.yaml)",
    )
    cli.add_argument(
        "--num_samples", type=int, default=5,
        help="Number of samples from IMDB test set to run inference on.",
    )
    options = cli.parse_args()
    # The namespace holds exactly config_path and num_samples, so it can be
    # forwarded directly as keyword arguments.
    run_sample_inference(**vars(options))