import logging
from datetime import datetime
from typing import List, Dict, Any

import datasets
from fastapi import HTTPException

from app.core.cache import cache_config
from app.config.hf_config import HF_TOKEN, HF_ORGANIZATION, HF_AGGREGATED
from app.utils.logging import LogFormatter

logger = logging.getLogger(__name__)


class LeaderboardService:
    def __init__(self):
        pass

    async def fetch_raw_data(self) -> List[Dict[str, Any]]:
        """Fetch raw leaderboard data from HuggingFace dataset"""
        try:
            logger.info(LogFormatter.section("FETCHING LEADERBOARD DATA"))
            logger.info(
                LogFormatter.info(
                    f"Loading dataset from {HF_ORGANIZATION}/{HF_AGGREGATED}"
                )
            )

            dataset = datasets.load_dataset(
                f"{HF_ORGANIZATION}/{HF_AGGREGATED}",
                cache_dir=cache_config.get_cache_path("datasets"),
            )["train"]

            df = dataset.to_pandas()
            data = df.to_dict("records")

            stats = {
                "Total_Entries": len(data),
                "Dataset_Size": f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.1f}MB",
            }
            for line in LogFormatter.stats(stats, "Dataset Statistics"):
                logger.info(line)

            return data

        except Exception as e:
            logger.error(LogFormatter.error("Failed to fetch leaderboard data", e))
            raise HTTPException(status_code=500, detail=str(e))

    async def get_formatted_data(self) -> List[Dict[str, Any]]:
        """Get formatted leaderboard data"""
        try:
            logger.info(LogFormatter.section("FORMATTING LEADERBOARD DATA"))

            raw_data = await self.fetch_raw_data()
            formatted_data = []
            type_counts = {}
            error_count = 0

            total_items = len(raw_data)
            logger.info(LogFormatter.info(f"Processing {total_items:,} entries..."))

            for i, item in enumerate(raw_data, 1):
                try:
                    formatted_item = await self.transform_data(item)
                    formatted_data.append(formatted_item)

                    model_type = formatted_item["model"]["type"]
                    type_counts[model_type] = type_counts.get(model_type, 0) + 1

                except Exception as e:
                    error_count += 1
                    logger.error(
                        LogFormatter.error(
                            f"Failed to format entry {i}/{total_items}", e
                        )
                    )
                    continue

                # Log progress roughly every 10% of processed entries.
                if i % max(1, total_items // 10) == 0:
                    logger.info(
                        LogFormatter.info(
                            f"Progress: {LogFormatter.progress_bar(i, total_items)}"
                        )
                    )

            stats = {
                "Total_Processed": total_items,
                "Successful": len(formatted_data),
                "Failed": error_count,
            }
            logger.info(LogFormatter.section("PROCESSING SUMMARY"))
            for line in LogFormatter.stats(stats, "Processing Statistics"):
                logger.info(line)

            type_stats = {f"Type_{k}": v for k, v in type_counts.items()}
            logger.info(LogFormatter.subsection("MODEL TYPE DISTRIBUTION"))
            for line in LogFormatter.stats(type_stats):
                logger.info(line)

            return formatted_data

        except Exception as e:
            logger.error(LogFormatter.error("Failed to format leaderboard data", e))
            raise HTTPException(status_code=500, detail=str(e))

    def _calculate_average_score(self, data: Dict[str, Any]) -> float:
        """Average the normalized metric values across the evaluation tasks."""
        scores = []
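        # Hypothetical example of the averaging: for an entry such as
        # {"MedQA": {"accuracy": 0.60}, "PubMedQA": {"accuracy": 0.70, "bart": 0.40}}
        # only 0.60 and 0.70 are kept ("bart" is excluded below), giving 0.65.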
        for key in [
            "BC5CDR-chemical",
            "NCBI-disease",
            "ChemProt",
            "DDI2013",
            "HoC",
            "LitCovid",
            "MedQA",
            "PubMedQA",
            "PubMed",
            "MS^2",
            "Cochrane PLS",
            "PLOS",
        ]:
            normalized = data.get(key, {})
            if not isinstance(normalized, dict):
                continue

            for metric, value in normalized.items():
                if metric in ["bart", "dcr", "fkg"]:
                    continue
                if isinstance(value, (int, float)):
                    scores.append(value)
        return sum(scores) / len(scores) if scores else 0.0

    async def transform_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform raw data into the format expected by the frontend"""
        try:
            model_name = data.get("fullname", "Unknown")
            logger.debug(
                LogFormatter.info(f"Transforming data for model: {model_name}")
            )

            unique_id = f"{data.get('fullname', 'Unknown')}_{data.get('Precision', 'Unknown')}_{data.get('Model sha', 'Unknown')}_{str(data.get('Chat Template', False))}"
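
            # Each evaluation entry pairs the task's raw metric ("<task> Raw" column)
            # with its normalized score ("<task>" column) from the aggregated dataset.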
            evaluations = {
                "bc5cdr_chemical": {
                    "name": "BC5CDR-chemical",
                    "value": data.get("BC5CDR-chemical Raw", {}),
                    "normalized_score": data.get("BC5CDR-chemical", {}),
                },
                "ncbi_disease": {
                    "name": "NCBI-disease",
                    "value": data.get("NCBI-disease Raw", {}),
                    "normalized_score": data.get("NCBI-disease", {}),
                },
                "chemprot": {
                    "name": "ChemProt",
                    "value": data.get("ChemProt Raw", {}),
                    "normalized_score": data.get("ChemProt", {}),
                },
                "ddi2013": {
                    "name": "DDI2013",
                    "value": data.get("DDI2013 Raw", {}),
                    "normalized_score": data.get("DDI2013", {}),
                },
                "hoc": {
                    "name": "HoC",
                    "value": data.get("HoC Raw", {}),
                    "normalized_score": data.get("HoC", {}),
                },
                "litcovid": {
                    "name": "LitCovid",
                    "value": data.get("LitCovid Raw", {}),
                    "normalized_score": data.get("LitCovid", {}),
                },
                "medqa": {
                    "name": "MedQA (5-Option)",
                    "value": data.get("MedQA Raw", {}),
                    "normalized_score": data.get("MedQA", {}),
                },
                "pubmedqa": {
                    "name": "PubMedQA",
                    "value": data.get("PubMedQA Raw", {}),
                    "normalized_score": data.get("PubMedQA", {}),
                },
                "pubmed": {
                    "name": "PubMed",
                    "value": data.get("PubMed Raw", {}),
                    "normalized_score": data.get("PubMed", {}),
                },
                "ms2": {
                    "name": "MS^2",
                    "value": data.get("MS^2 Raw", {}),
                    "normalized_score": data.get("MS^2", {}),
                },
                "cochrane_pls": {
                    "name": "Cochrane PLS",
                    "value": data.get("Cochrane PLS Raw", {}),
                    "normalized_score": data.get("Cochrane PLS", {}),
                },
                "plos": {
                    "name": "PLOS",
                    "value": data.get("PLOS Raw", {}),
                    "normalized_score": data.get("PLOS", {}),
                },
            }

            features = {
                "is_not_available_on_hub": data.get("Available on the hub", False),
                "is_merged": data.get("Merged", False),
                "is_moe": data.get("MoE", False),
                "is_flagged": data.get("Flagged", False),
                "is_highlighted_by_maintainer": data.get("Official Providers", False),
            }

            metadata = {
                "upload_date": data.get("Upload To Hub Date"),
                "submission_date": data.get("Submission Date"),
                "generation": data.get("Generation"),
                "base_model": data.get("Base Model"),
                "hub_license": data.get("Hub License"),
                "hub_hearts": data.get("Hub ❤️"),
                "params_billions": data.get("#Params (B)"),
                "co2_cost": data.get("CO₂ cost (kg)", 0),
            }

            # Normalize the model type string before mapping it to a frontend value.
            original_type = data.get("Type", "")
            model_type = original_type.lower().strip()

            if "(" in model_type:
                model_type = model_type.split("(")[0].strip()
            model_type = "".join(c for c in model_type if c not in "🔶🟢🟩💬🤝🌸 ")

            model_type_mapping = {
                "fine-tuned": "fined-tuned-on-domain-specific-dataset",
                "fine tuned": "fined-tuned-on-domain-specific-dataset",
                "finetuned": "fined-tuned-on-domain-specific-dataset",
                "fine_tuned": "fined-tuned-on-domain-specific-dataset",
                "ft": "fined-tuned-on-domain-specific-dataset",
                "finetuning": "fined-tuned-on-domain-specific-dataset",
                "fine tuning": "fined-tuned-on-domain-specific-dataset",
                "fine-tuning": "fined-tuned-on-domain-specific-dataset",
            }
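
            # Hypothetical example of the normalization above: a raw type such as
            # "🔶 fine-tuned (merged)" becomes "fine-tuned" and is then mapped to
            # "fined-tuned-on-domain-specific-dataset".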

            mapped_type = model_type_mapping.get(model_type.lower().strip(), model_type)

            if mapped_type != model_type:
                logger.debug(
                    LogFormatter.info(
                        f"Model type mapped: {original_type} -> {mapped_type}"
                    )
                )

            transformed_data = {
                "id": unique_id,
                "model": {
                    "name": data.get("fullname"),
                    "sha": data.get("Model sha"),
                    "precision": data.get("Precision"),
                    "type": mapped_type,
                    "weight_type": data.get("Weight type"),
                    "architecture": data.get("Architecture"),
                    "average_score": self._calculate_average_score(data),
                    "has_chat_template": data.get("Chat Template", False),
                },
                "evaluations": evaluations,
                "features": features,
                "metadata": metadata,
            }

            logger.debug(
                LogFormatter.success(f"Successfully transformed data for {model_name}")
            )
            return transformed_data

        except Exception as e:
            logger.error(
                LogFormatter.error(
                    f"Failed to transform data for {data.get('fullname', 'Unknown')}", e
                )
            )
            raise
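

# ---------------------------------------------------------------------------
# Minimal local smoke-test sketch (illustrative only, not part of the service
# API). It assumes the application configuration (cache paths, HF_ORGANIZATION
# and HF_AGGREGATED) is in place and the aggregated dataset is reachable.
if __name__ == "__main__":
    import asyncio

    async def _main() -> None:
        service = LeaderboardService()
        entries = await service.get_formatted_data()
        # Print a short summary rather than the full payload.
        print(f"Fetched {len(entries)} formatted leaderboard entries")

    asyncio.run(_main())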