ash-98's picture
Initial Commit
d48ac80
from typing import List,Dict
import re
def parse_model_entries(model_entries: List[str]) -> List[Dict[str, str]]:
    """
    Parse a list of model entries into structured dictionaries with provider, model name, version, region, and type.

    Each entry is stripped of surrounding whitespace (e.g. the trailing newline
    left over when reading lines from a file) before it is analysed. Exactly one
    dictionary is produced per input entry, in input order, so results can be
    matched back to the input by index.

    Args:
        model_entries: List of model entry strings as found in models.txt

    Returns:
        List of dictionaries with parsed model information containing keys:
        - provider: Lower-case name of a known provider, or '' if none was recognised
        - model_name: Base name of the model (last '/'-separated part, minus version)
        - version: Version of the model, or '' if none was recognised
        - region: Deployment region (bedrock/azure entries only), or ''
        - model_type: 'image', 'audio' or 'text' (default), based on keyword matching
    """
    # Common provider prefixes to identify. Order matters: when the provider is
    # searched for inside the model name, the first match in this list wins.
    known_providers = [
        'azure', 'bedrock', 'anthropic', 'openai', 'cohere', 'google',
        'mistral', 'meta', 'amazon', 'ai21', 'anyscale', 'stability',
        'cloudflare', 'databricks', 'cerebras', 'assemblyai'
    ]
    # Providers whose second path component may be a deployment region
    region_providers = ('bedrock', 'azure')
    # Image-related keywords to identify image models (substring match)
    image_indicators = ['dall-e', 'stable-diffusion', 'image', 'canvas', 'x-', 'steps']
    # Audio-related keywords to identify audio models (substring match)
    audio_indicators = ['whisper', 'tts', 'audio', 'voice']

    # Compiled once, outside the loop.
    # Trailing "v<number>" version, optionally followed by ":<number>",
    # e.g. "-v1:0" -> "1", "-v2" -> "2".
    version_pattern = re.compile(r'[:.-]v(\d+(?:\.\d+)*(?:-\d+)?|\d+)(?::\d+)?$')
    # Trailing date-based version like "-2024-08-06".
    date_pattern = re.compile(r'-(\d{4}-\d{2}-\d{2})$')

    parsed_models = []
    for raw_entry in model_entries:
        # Normalise whitespace so "gpt-4\n" and " azure/gpt-4o" parse like
        # their clean counterparts.
        entry = raw_entry.strip()
        entry_lower = entry.lower()

        model_info = {
            'provider': '',
            'model_name': '',
            'version': '',
            'region': '',
            'model_type': 'text'  # Default to text
        }

        # Image indicators take precedence over audio indicators
        if any(indicator in entry_lower for indicator in image_indicators):
            model_info['model_type'] = 'image'
        elif any(indicator in entry_lower for indicator in audio_indicators):
            model_info['model_type'] = 'audio'

        # Parse the entry based on common patterns
        parts = entry.split('/')
        if len(parts) >= 2:
            # Extract provider from the beginning (common pattern)
            first_part = parts[0].lower()
            if first_part in known_providers:
                model_info['provider'] = first_part
            # For bedrock and azure, the region is often the next part,
            # unless that part is a commitment term.
            if (
                first_part in region_providers
                and len(parts) >= 3
                and 'commitment' not in parts[1]
            ):
                model_info['region'] = parts[1]
            # The last part typically contains the model name and possibly version
            model_with_version = parts[-1]
        else:
            # For single-part entries
            model_with_version = entry

        # Extract provider from model name if not already set
        if not model_info['provider']:
            name_lower = model_with_version.lower()
            for provider in known_providers:
                if provider in name_lower:
                    model_info['provider'] = provider
                    # Remove provider prefix if it exists at the beginning
                    if name_lower.startswith(f'{provider}.'):
                        model_with_version = model_with_version[len(provider) + 1:]
                    break

        # Extract version information: "v<N>" suffix first, then a date suffix
        version_match = version_pattern.search(model_with_version)
        if version_match:
            model_info['version'] = version_match.group(1)
            model_name = model_with_version[:version_match.start()]
        else:
            date_match = date_pattern.search(model_with_version)
            if date_match:
                model_info['version'] = date_match.group(1)
                model_name = model_with_version[:date_match.start()]
            else:
                model_name = model_with_version

        # Clean up model name by removing trailing/leading separators
        model_info['model_name'] = model_name.strip('.-:')
        parsed_models.append(model_info)

    return parsed_models
def create_model_hierarchy(model_entries: List[str]) -> Dict[str, Dict[str, Dict[str, Dict[str, str]]]]:
    """
    Group model entries into a four-level nested dictionary.

    Args:
        model_entries: List of model entry strings as found in models.txt

    Returns:
        Nested dictionary shaped as:
        provider -> model name -> version -> region -> original entry string
        An empty provider becomes "unknown"; an empty version or region
        becomes "NA". Entries that resolve to the same four keys overwrite
        each other, so the last such entry wins.
    """
    tree = {}
    records = parse_model_entries(model_entries)

    for position, record in enumerate(records):
        provider_key = record['provider'] or 'unknown'
        name_key = record['model_name']
        version_key = record['version'] or 'NA'

        # Azure models always get 'NA' as region since they are treated as
        # globally available; any parsed region is ignored for them.
        if provider_key == 'azure':
            region_key = 'NA'
        else:
            region_key = record['region'] or 'NA'

        # Create the intermediate levels on demand, then store the full
        # model string at the leaf.
        by_region = (
            tree.setdefault(provider_key, {})
            .setdefault(name_key, {})
            .setdefault(version_key, {})
        )
        by_region[region_key] = model_entries[position]

    return tree