File size: 5,921 Bytes
90c9a37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from typing import List,Dict
import re
# Provider names recognised either as the leading path segment of an entry or
# as a substring of the model identifier. Order matters: when scanning the
# identifier, the first provider found wins.
_KNOWN_PROVIDERS = (
    'azure', 'bedrock', 'anthropic', 'openai', 'cohere', 'google',
    'mistral', 'meta', 'amazon', 'ai21', 'anyscale', 'stability',
    'cloudflare', 'databricks', 'cerebras', 'assemblyai',
)
# Providers whose second path segment is interpreted as a deployment region.
_REGIONAL_PROVIDERS = ('bedrock', 'azure')
# Substrings (matched against the lower-cased entry) that mark a model type.
_IMAGE_INDICATORS = ('dall-e', 'stable-diffusion', 'image', 'canvas', 'x-', 'steps')
_AUDIO_INDICATORS = ('whisper', 'tts', 'audio', 'voice')
# Trailing "v<number>" suffix such as "-v1", ".v2.1" or ":v1:0".
_VERSION_RE = re.compile(r'[:.-]v(\d+(?:\.\d+)*(?:-\d+)?|\d+)(?::\d+)?$')
# Trailing ISO-style date suffix such as "-2024-08-06".
_DATE_VERSION_RE = re.compile(r'-(\d{4}-\d{2}-\d{2})$')


def parse_model_entries(model_entries: List[str]) -> List[Dict[str, str]]:
    """
    Parse model entry strings into dictionaries describing each model.

    Surrounding whitespace (including a trailing newline left over from
    reading a text file) is stripped from every entry before parsing. The
    result always has exactly one dictionary per input entry, in the same
    order, so callers can pair results with inputs by position; a blank
    entry yields a dictionary with empty fields and model_type 'text'.

    Args:
        model_entries: List of model entry strings, e.g.
            'bedrock/us-east-1/anthropic.claude-v2:1' or 'gpt-4o-2024-08-06'.

    Returns:
        List of dictionaries, each with the string keys:
        - provider: Lower-cased provider name, or '' if none was recognised.
        - model_name: Model identifier with provider prefix and version removed.
        - version: Version number or date suffix, or '' if none was found.
        - region: Second path segment for bedrock/azure entries with at
          least three segments (unless it contains 'commitment'), else ''.
        - model_type: 'image', 'audio' or 'text' (the default), chosen by
          keyword matching on the whole entry; image keywords take priority.
    """
    parsed_models = []

    for raw_entry in model_entries:
        entry = raw_entry.strip()
        lowered_entry = entry.lower()

        # Image indicators are checked first, so they win over audio ones.
        if any(keyword in lowered_entry for keyword in _IMAGE_INDICATORS):
            model_type = 'image'
        elif any(keyword in lowered_entry for keyword in _AUDIO_INDICATORS):
            model_type = 'audio'
        else:
            model_type = 'text'

        provider = ''
        region = ''
        parts = entry.split('/')

        if len(parts) >= 2:
            head = parts[0].lower()
            if head in _KNOWN_PROVIDERS:
                provider = head
                # "<provider>/<region>/<model>" layout; segments mentioning
                # 'commitment' are not regions and are skipped.
                if (
                    head in _REGIONAL_PROVIDERS
                    and len(parts) >= 3
                    and 'commitment' not in parts[1]
                ):
                    region = parts[1]

        # The last segment holds the model name and, possibly, its version.
        # For an entry without '/', this is the whole entry.
        model_with_version = parts[-1]

        if not provider:
            # Fall back to spotting a provider name inside the identifier.
            lowered_model = model_with_version.lower()
            for known in _KNOWN_PROVIDERS:
                if known in lowered_model:
                    provider = known
                    # Drop a leading "<provider>." prefix from the name.
                    if lowered_model.startswith(f'{known}.'):
                        model_with_version = model_with_version[len(known) + 1:]
                    break

        # Prefer an explicit "v<number>" suffix; otherwise accept a date suffix.
        suffix_match = (
            _VERSION_RE.search(model_with_version)
            or _DATE_VERSION_RE.search(model_with_version)
        )
        if suffix_match:
            version = suffix_match.group(1)
            model_name = model_with_version[:suffix_match.start()]
        else:
            version = ''
            model_name = model_with_version

        parsed_models.append({
            'provider': provider,
            # Remove separators left dangling at either end of the name.
            'model_name': model_name.strip('.-:'),
            'version': version,
            'region': region,
            'model_type': model_type,
        })

    return parsed_models
def create_model_hierarchy(model_entries: List[str]) -> Dict[str, Dict[str, Dict[str, Dict[str, str]]]]:
    """
    Build a provider -> model -> version -> region lookup of model entries.

    Args:
        model_entries: List of model entry strings as found in models.txt

    Returns:
        Nested dictionary whose leaves are the original entry strings:
            hierarchy[provider][model_name][version][region] = entry
        An empty provider becomes 'unknown'; an empty version or region
        becomes 'NA'. When several entries resolve to the same path, the
        last one in the input is the one kept.
    """
    tree = {}

    # Walk each original entry alongside its parsed description.
    for original, info in zip(model_entries, parse_model_entries(model_entries)):
        provider_key = info['provider'] or 'unknown'
        version_key = info['version'] or 'NA'

        # Azure models are always filed under 'NA', ignoring any parsed region.
        if provider_key == 'azure':
            region_key = 'NA'
        else:
            region_key = info['region'] or 'NA'

        # Create the intermediate levels on demand and descend to the version.
        by_region = (
            tree.setdefault(provider_key, {})
            .setdefault(info['model_name'], {})
            .setdefault(version_key, {})
        )
        by_region[region_key] = original

    return tree