# Updated generator.py with proper function order
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torchaudio
import logging
import os
import time
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
from app.models import Segment
from app.text_normalizer import clean_text_for_tts
from app.text_normalizer import TextNormalizer
# Set up logging
logger = logging.getLogger(__name__)
# Import the CSM watermarking code
try:
from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
except ImportError:
# Define stubs for watermarking if the module is not available
CSM_1B_GH_WATERMARK = "CSM1B"
def load_watermarker(device="cpu"):
return None
def watermark(watermarker, audio, sample_rate, key):
return audio, sample_rate
def load_llama3_tokenizer():
"""
Load tokenizer for Llama 3.2, using unsloth's open version
instead of the gated meta-llama version.
"""
try:
# Use the unsloth version which is not gated
tokenizer_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
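        # Behavioral note (assumption about the stock tokenizer): this template makes
        # encode() wrap every sequence with both BOS and EOS, which is the framing the
        # CSM prompt format expects, rather than the tokenizer's default BOS-only output.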
logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
return tokenizer
except Exception as e:
logger.error(f"Error loading tokenizer from unsloth: {e}")
# Fallback to a simpler tokenizer if needed
try:
from transformers import GPT2Tokenizer
            logger.warning("Falling back to GPT2Tokenizer (its vocabulary does not match the Llama 3 text embeddings, so speech quality will likely degrade)")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
except Exception as fallback_e:
logger.error(f"Fallback tokenizer also failed: {fallback_e}")
raise RuntimeError("Could not load any suitable tokenizer")
class Generator:
"""Generator class for CSM-1B model."""
def __init__(self, model):
"""Initialize generator with model."""
self._model = model
self._model.setup_caches(1)
self._text_tokenizer = load_llama3_tokenizer()
device = next(model.parameters()).device
# Load Mimi codec for audio tokenization
try:
logger.info("Loading Mimi audio codec...")
from huggingface_hub import hf_hub_download
# First try to import from moshi
try:
from moshi.models import loaders
DEFAULT_REPO = loaders.DEFAULT_REPO
MIMI_NAME = loaders.MIMI_NAME
get_mimi = loaders.get_mimi
except ImportError:
logger.warning("moshi.models.loaders not found, using fallback")
# Fallback values if moshi.models.loaders is not available
DEFAULT_REPO = "kyutai/mimi"
MIMI_NAME = "mimi-december.pt"
# Fallback function to load mimi
def get_mimi(checkpoint_path, device):
from moshi.models.vqvae_model import MiMiModule
checkpoint = torch.load(checkpoint_path, map_location=device)
model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
return model
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
mimi = get_mimi(mimi_weight, device=device)
mimi.set_num_codebooks(32)
self._audio_tokenizer = mimi
self.sample_rate = mimi.sample_rate
logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
        except Exception as e:
            logger.error(f"Error loading Mimi codec: {e}")
            # The Mimi codec is required both to tokenize context audio and to decode
            # generated frames, so fail fast instead of continuing with a broken generator.
            raise RuntimeError(f"Failed to load Mimi codec: {e}") from e
try:
self._watermarker = load_watermarker(device=device)
logger.info("Watermarker loaded successfully")
except Exception as e:
logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
self._watermarker = None
self.device = device
# Optimize for CUDA throughput
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
logger.info("CUDA optimizations enabled")
def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Tokenize a text segment."""
frame_tokens = []
frame_masks = []
# Strip any voice instructions in square brackets to avoid them being read out
text = self._clean_text_input(text)
text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
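        # Layout note (as constructed in this file): every frame is a 33-wide row in
        # which columns 0..31 are reserved for the 32 Mimi audio codebooks and the
        # final column (index 32) holds a text token id, so the text-only frames built
        # below populate just the last column.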
text_frame = torch.zeros(len(text_tokens), 33).long()
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
text_frame[:, -1] = torch.tensor(text_tokens)
text_frame_mask[:, -1] = True
frame_tokens.append(text_frame.to(self.device))
frame_masks.append(text_frame_mask.to(self.device))
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _clean_text_input(self, text: str) -> str:
"""Clean and normalize text for TTS."""
return clean_text_for_tts(text)
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Tokenize audio."""
if self._audio_tokenizer is None:
raise RuntimeError("Audio tokenizer not initialized")
frame_tokens = []
frame_masks = []
# (K, T)
audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
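        # audio_tokens is now (32, T + 1): one row per Mimi codebook plus the all-zero
        # EOS column appended above; it is transposed below so each output row is one
        # frame occupying codebook columns 0..31 of the 33-wide layout.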
audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :-1] = True
frame_tokens.append(audio_frame)
frame_masks.append(audio_frame_mask)
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
"""Tokenize a segment of text and audio."""
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
    @torch.inference_mode()
    def generate_quick(
self,
text: str,
speaker: int,
context: List[Segment],
max_audio_length_ms: float = 2000, # Short for quick generation
temperature: float = 0.7, # Lower for more predictable output
topk: int = 20, # Lower for faster beam selection
) -> torch.Tensor:
"""Generate audio quickly for real-time streaming."""
# Similar to generate() but optimized for speed
self._model.reset_caches()
# Convert max_audio_length_ms to frames - limit for faster generation
max_audio_frames = min(int(max_audio_length_ms / 80), 128) # Smaller limit
# Process text
cleaned_text = clean_text_for_tts(text)
# Prepare tokens
tokens, tokens_mask = [], []
# Add context segments (limited to 1 for speed)
if context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
# Add text tokens
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
# Generate with larger batch size for fewer iterations
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
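        # Decoding is frame-autoregressive: the first generate_frame call consumes the
        # full prompt, and from then on curr_tokens is just the previously sampled
        # frame (shape (1, 1, 33) with a zero text column) while curr_pos advances by
        # one position per generated frame.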
# Use larger batch size
batch_size = 64 # Generate more frames at once
all_samples = []
for start_idx in range(0, max_audio_frames, batch_size):
end_idx = min(start_idx + batch_size, max_audio_frames)
batch_frames = end_idx - start_idx
samples_batch = []
for i in range(batch_frames):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            # Stop on the all-zero EOS frame without appending it to the output
            if torch.all(sample == 0):
                break
            samples_batch.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
all_samples.extend(samples_batch)
if len(samples_batch) < batch_frames:
break
if not all_samples:
return torch.zeros(10, device=self.device) # Return short empty audio
# Decode audio
audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
return audio
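    # Illustrative call (hypothetical values): generate_quick targets short, low-latency
    # chunks for streaming, e.g. generator.generate_quick("One moment.", speaker=0, context=[]).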
@torch.inference_mode()
def generate(
self,
text: str,
speaker: int,
context: List[Segment],
max_audio_length_ms: float = 90_000,
temperature: float = 0.9,
topk: int = 50,
) -> torch.Tensor:
"""Generate audio from text."""
if self._audio_tokenizer is None:
raise RuntimeError("Audio tokenizer not initialized")
        # Start timing (CUDA events require a CUDA device; fall back to wall-clock time on CPU)
        use_cuda_timing = torch.cuda.is_available()
        if use_cuda_timing:
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()
        else:
            wall_start = time.perf_counter()
self._model.reset_caches()
# Convert max_audio_length_ms to frames - this controls the maximum generation length
max_audio_frames = min(int(max_audio_length_ms / 80), 1024) # Limit to reasonable size
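        # Arithmetic: each frame covers 80 ms of audio (12.5 frames/s), so the default
        # 90_000 ms request maps to 1125 frames and is capped at 1024 (~82 s) to leave
        # room inside the 2048-position context window computed below.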
max_seq_len = 2048 - max_audio_frames
# Check if text is long and should be split
if len(text) > 200:
logger.info(f"Long text detected ({len(text)} chars), processing in segments")
sentences = TextNormalizer.split_into_sentences(text)
logger.info(f"Split into {len(sentences)} segments")
# Process sentences individually and concatenate the results
all_audio_segments = []
# Use the first sentence to establish voice
first_sentence = sentences[0]
cleaned_text = clean_text_for_tts(first_sentence)
# Generate the first segment
tokens, tokens_mask = [], []
# Add context segments for the first sentence
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
# Add first sentence tokens
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
# Check context size and truncate if needed
if prompt_tokens.size(0) >= max_seq_len:
logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
prompt_tokens = prompt_tokens[-max_seq_len+50:]
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
# Generate first sentence audio
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
# Generate first segment
first_segment_samples = []
for start_idx in range(0, max_audio_frames, 32):
end_idx = min(start_idx + 32, max_audio_frames)
batch_frames = end_idx - start_idx
samples_batch = []
for i in range(batch_frames):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    # Stop on the all-zero EOS frame without appending it to the output
                    if torch.all(sample == 0):
                        break
                    samples_batch.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
first_segment_samples.extend(samples_batch)
if len(samples_batch) < batch_frames:
break
if not first_segment_samples:
raise RuntimeError("No audio generated for first segment")
# Decode first segment
first_segment_audio = self._audio_tokenizer.decode(
torch.stack(first_segment_samples).permute(1, 2, 0)
).squeeze(0).squeeze(0)
all_audio_segments.append(first_segment_audio)
# Now process remaining sentences using the first as context
for i, sentence in enumerate(sentences[1:], 1):
logger.info(f"Generating segment {i+1}/{len(sentences)}")
cleaned_text = clean_text_for_tts(sentence)
# Create a context segment from the previous generation
prev_segment = Segment(
speaker=speaker,
text=sentences[i-1],
audio=all_audio_segments[-1]
)
                # Generate with this segment as context (tokenize it once and reuse both outputs)
                prev_tokens, prev_masks = self._tokenize_segment(prev_segment)
                segment_tokens, segment_tokens_mask = [prev_tokens], [prev_masks]
# Add current segment tokens
current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
segment_tokens.append(current_tokens)
segment_tokens_mask.append(current_tokens_mask)
segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device)
segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device)
# Check length and truncate if needed
if segment_prompt_tokens.size(0) >= max_seq_len:
segment_prompt_tokens = segment_prompt_tokens[-max_seq_len+50:]
segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len+50:]
# Generate audio for this segment
curr_tokens = segment_prompt_tokens.unsqueeze(0)
curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
# Generate segment
segment_samples = []
for start_idx in range(0, max_audio_frames, 32):
end_idx = min(start_idx + 32, max_audio_frames)
batch_frames = end_idx - start_idx
samples_batch = []
                    # Use "_" here so the outer sentence index i is not shadowed
                    for _ in range(batch_frames):
                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                        # Stop on the all-zero EOS frame without appending it to the output
                        if torch.all(sample == 0):
                            break
                        samples_batch.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
segment_samples.extend(samples_batch)
if len(samples_batch) < batch_frames:
break
if not segment_samples:
logger.warning(f"No audio generated for segment {i+1}")
continue
# Decode segment
segment_audio = self._audio_tokenizer.decode(
torch.stack(segment_samples).permute(1, 2, 0)
).squeeze(0).squeeze(0)
all_audio_segments.append(segment_audio)
# Combine all segments with small pauses
pause_samples = int(0.3 * self.sample_rate) # 300ms pause
pause = torch.zeros(pause_samples, device=self.device)
audio_parts = []
for i, segment_audio in enumerate(all_audio_segments):
audio_parts.append(segment_audio)
if i < len(all_audio_segments) - 1:
audio_parts.append(pause)
audio = torch.cat(audio_parts)
logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
else:
# For shorter text, standard processing
tokens, tokens_mask = [], []
# Add context segments
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
# Process text
cleaned_text = clean_text_for_tts(text)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
# Check context size
if prompt_tokens.size(0) >= max_seq_len:
logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
prompt_tokens = prompt_tokens[-max_seq_len+50:]
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
# Generate audio - optimized batch generation
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
# Using optimized batch generation
batch_size = 32 # Generate this many frames at once
all_samples = []
for start_idx in range(0, max_audio_frames, batch_size):
end_idx = min(start_idx + batch_size, max_audio_frames)
batch_frames = end_idx - start_idx
samples_batch = []
for i in range(batch_frames):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    # Stop on the all-zero EOS frame without appending it to the output
                    if torch.all(sample == 0):
                        break
                    samples_batch.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
all_samples.extend(samples_batch)
if len(samples_batch) < batch_frames:
logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
break
if not all_samples:
raise RuntimeError("No audio generated - model produced empty output")
# Decode audio
audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
# Apply watermark
if self._watermarker is not None:
try:
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
except Exception as e:
logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
        # Record execution time
        if use_cuda_timing:
            end_time.record()
            torch.cuda.synchronize()
            execution_ms = start_time.elapsed_time(end_time)
        else:
            execution_ms = (time.perf_counter() - wall_start) * 1000
audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
# Calculate real-time factor (RTF)
rtf = execution_ms / audio_length_ms
logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
return audio
# Define helper functions for multi-GPU support
def _manual_device_map(model, state_dict, strategy="balanced"):
"""Apply manual device mapping for multi-GPU setups.
Args:
model: The model to map
state_dict: Model state dict
strategy: Mapping strategy ('balanced', 'sequential')
Returns:
Model with weights distributed across GPUs
"""
num_gpus = torch.cuda.device_count()
if num_gpus <= 1:
# No need for mapping with single GPU
model.load_state_dict(state_dict)
model = model.to("cuda")
return model
logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
# Get all layer names from state dict
layer_names = [name for name in state_dict.keys() if "layers" in name]
backbone_layers = [name for name in layer_names if "backbone.layers" in name]
decoder_layers = [name for name in layer_names if "decoder.layers" in name]
# Count number of backbone and decoder layers
backbone_layer_indices = set()
for name in backbone_layers:
parts = name.split('.')
if len(parts) > 2:
try:
backbone_layer_indices.add(int(parts[2]))
except ValueError:
pass
decoder_layer_indices = set()
for name in decoder_layers:
parts = name.split('.')
if len(parts) > 2:
try:
decoder_layer_indices.add(int(parts[2]))
except ValueError:
pass
num_backbone_layers = len(backbone_layer_indices)
num_decoder_layers = len(decoder_layer_indices)
# Create device map
device_map = {}
if strategy == "balanced":
# Distribute layers evenly across GPUs
        # Guard against a zero divisor when there are fewer layers than GPUs
        layers_per_gpu = max(1, (num_backbone_layers + num_decoder_layers) // num_gpus)
# Assign backbone layers
for i in backbone_layer_indices:
gpu_idx = min(i // layers_per_gpu, num_gpus - 1)
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
# Assign decoder layers
for i in decoder_layer_indices:
gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
elif strategy == "sequential":
# Fill each GPU sequentially
# Backbone layers on first GPU(s)
backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2))
for i in backbone_layer_indices:
gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1)
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
# Decoder layers on remaining GPU(s)
decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1))
for i in decoder_layer_indices:
gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1)
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
# Assign embeddings and other components
device_map["text_embeddings"] = "cuda:0"
device_map["audio_embeddings"] = "cuda:0"
device_map["projection"] = "cuda:0"
device_map["codebook0_head"] = "cuda:0"
device_map["audio_head"] = "cuda:0"
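    # Illustrative (hypothetical) mapping for num_gpus=2, a 16-layer backbone and a
    # 4-layer decoder under the "balanced" strategy (layers_per_gpu = 10):
    #   backbone.layers.0-9  -> cuda:0,  backbone.layers.10-15 -> cuda:1,
    #   decoder.layers.0-3   -> cuda:1,  embeddings/heads      -> cuda:0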
# Load state dict with device mapping
model.load_state_dict(state_dict)
# Move model parts to assigned devices
for name, device in device_map.items():
if "backbone.layers" in name:
layer_idx = int(name.split('.')[-1])
if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
elif "decoder.layers" in name:
layer_idx = int(name.split('.')[-1])
if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
elif hasattr(model, name):
setattr(model, name, getattr(model, name).to(device))
logger.info(f"Model distributed across GPUs with {strategy} strategy")
return model
def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: Optional[str] = None) -> Generator:
"""Load CSM-1B model and create generator with performance optimizations.
Args:
ckpt_path: Path to model checkpoint
device: Device to load model on ('cuda', 'cpu', or specific CUDA device)
device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)
Returns:
Generator instance with optimized settings
"""
try:
# Import models module for CSM
from app.torchtune_models import Model, ModelArgs
# Create model
model_args = ModelArgs(
backbone_flavor="llama-1B",
decoder_flavor="llama-100M",
text_vocab_size=128256,
audio_vocab_size=2051,
audio_num_codebooks=32,
)
# Load model
logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
# Check for CUDA availability
cuda_available = device == "cuda" and torch.cuda.is_available()
# Set up torch for optimized inference
if cuda_available:
# Check if we should enable TF32 (faster but slightly less precise)
enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
if enable_tf32:
logger.info("Enabling TF32 for faster matrix multiplications")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Check for available precision modes
use_bfloat16 = torch.cuda.is_bf16_supported()
use_float16 = not use_bfloat16 and torch.cuda.is_available() # Fallback to float16
if use_bfloat16:
dtype = torch.bfloat16
logger.info("Using bfloat16 precision for faster inference")
elif use_float16:
dtype = torch.float16
logger.info("Using float16 precision for faster inference")
else:
dtype = torch.float32
logger.info("Using float32 precision (mixed precision not available)")
# Enable Flash Attention if available
try:
import flash_attn
if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
logger.info("Flash Attention detected - enabling for faster attention")
os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
except ImportError:
logger.info("Flash Attention not available (install flash-attn for faster inference)")
else:
# CPU-only mode
dtype = torch.float32
logger.info("Using CPU mode with float32 precision")
# Check for quantization
enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
is_quantized = False
# Check for multi-GPU setup
if device_map and torch.cuda.device_count() > 1:
logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
# Create model with device map
model = Model(model_args)
# Load state dict
state_dict = torch.load(ckpt_path, map_location='cpu')
# Try quantization before device mapping if enabled
if enable_quantization and cuda_available:
try:
from bitsandbytes.nn import Linear8bitLt
def replace_with_8bit(model):
"""Replace linear layers with 8-bit quantized versions"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and module.out_features > 256:
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
parent = model
if parent_name:
for attr in parent_name.split('.'):
parent = getattr(parent, attr)
child_name = name.rsplit('.', 1)[1] if '.' in name else name
setattr(parent, child_name, Linear8bitLt.from_float(module))
return model
logger.info("Applying 8-bit quantization to linear layers")
model = replace_with_8bit(model)
is_quantized = True
except ImportError:
logger.warning("bitsandbytes not available, skipping quantization")
# Apply device mapping
if device_map == "auto":
# Use accelerate for automatic device mapping
try:
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
# Initialize empty model
with init_empty_weights():
empty_model = Model(model_args)
# Load and dispatch model across GPUs
model = load_checkpoint_and_dispatch(
empty_model,
ckpt_path,
device_map="auto",
no_split_module_classes=["TransformerLayer"],
                        # Offload to CPU if the model is too large to fit on the GPUs
offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
)
logger.info("Model loaded with automatic device mapping")
except ImportError:
logger.warning("accelerate package not found, falling back to manual device mapping")
model = _manual_device_map(model, state_dict, "balanced")
except Exception as mapping_error:
logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
model = _manual_device_map(model, state_dict, "balanced")
else:
# Manual device mapping
model = _manual_device_map(model, state_dict, device_map or "balanced")
else:
# Single GPU or CPU setup
# Try quantization before loading if enabled (GPU only)
if enable_quantization and cuda_available and not is_quantized:
try:
# First load to CPU for quantization
model = Model(model_args).to("cpu")
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
from bitsandbytes.nn import Linear8bitLt
def replace_with_8bit(model):
"""Replace linear layers with 8-bit quantized versions"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and module.out_features > 256:
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
parent = model
if parent_name:
for attr in parent_name.split('.'):
parent = getattr(parent, attr)
child_name = name.rsplit('.', 1)[1] if '.' in name else name
setattr(parent, child_name, Linear8bitLt.from_float(module))
return model
logger.info("Applying 8-bit quantization to linear layers")
model = replace_with_8bit(model)
model = model.to(device=device)
is_quantized = True
except ImportError:
logger.warning("bitsandbytes not available, loading without quantization")
# Load the standard way
model = Model(model_args).to(device=device, dtype=dtype)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)
except Exception as quant_error:
logger.error(f"Quantization failed: {quant_error}, loading without quantization")
# Load the standard way
model = Model(model_args).to(device=device, dtype=dtype)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)
else:
# Standard load without quantization
model = Model(model_args).to(device=device, dtype=dtype)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)
# Apply torch.compile if available (PyTorch 2.0+)
compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
try:
logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
if compile_mode == "default":
model = torch.compile(model)
else:
model = torch.compile(model, mode=compile_mode)
except Exception as compile_error:
logger.warning(f"Torch compile failed (requires PyTorch 2.0+): {compile_error}")
# Try to optimize CUDA graphs for faster inference (advanced)
use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
try:
logger.info("Setting up CUDA graphs for repeated inference patterns")
# This requires custom integration inside the model's forward method
# Just flagging that CUDA graphs should be used
model.use_cuda_graphs = True
except Exception as cuda_graph_error:
logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
model.use_cuda_graphs = False
# Set optimal settings for CUDA context
if cuda_available:
# Set benchmark mode for hardware-specific optimizations
torch.backends.cudnn.benchmark = True
# Clean up CUDA cache before creating generator
torch.cuda.empty_cache()
# Ensure all CUDA work is completed to avoid launch delays
torch.cuda.synchronize()
# Create generator
logger.info("Creating generator with optimized settings")
generator = Generator(model)
# Log memory usage if on CUDA
if cuda_available:
memory_allocated = torch.cuda.memory_allocated() / (1024**3)
memory_reserved = torch.cuda.memory_reserved() / (1024**3)
logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
return generator
except Exception as e:
logger.error(f"Failed to load CSM-1B model: {e}")
import traceback
logger.error(traceback.format_exc())
raise
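# --- Usage sketch (illustrative only; the checkpoint path, speaker id and output
# filename below are placeholders, not values shipped with this repo) ---
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _generator = load_csm_1b(ckpt_path="ckpt.pt", device=_device)
    _audio = _generator.generate(text="Hello from CSM-1B.", speaker=0, context=[])
    # generate() returns a 1-D waveform tensor at the codec's sample rate
    torchaudio.save("output.wav", _audio.unsqueeze(0).cpu(), int(_generator.sample_rate))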