import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

from diffusionLM.model.diffusionLM import LLaDAModel


class DiffusionConfig(PretrainedConfig):
    """Configuration class for the Diffusion-LLM model.

    Holds the transformer hyperparameters, special token ids, and the
    diffusion-specific settings (number of timesteps and timestep embedding
    dimension) consumed by the underlying LLaDAModel.
    """

    model_type = "diffusionLM"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1024,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        mask_token_id: int = 50256,
        eos_token_id: int = 50256,
        num_timesteps: int = 100,
        time_embed_dim: int = 128,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.mask_token_id = mask_token_id
        self.eos_token_id = eos_token_id
        self.num_timesteps = num_timesteps
        self.time_embed_dim = time_embed_dim
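
# Usage sketch (illustrative, not part of the module): the defaults above match
# a GPT-2-sized vocabulary, so a smaller config for quick experiments might look
# like the commented example below. `save_pretrained` / `from_pretrained` are
# inherited from PretrainedConfig.
#
#   config = DiffusionConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=4)
#   config.save_pretrained("./diffusionLM-small")
#   config = DiffusionConfig.from_pretrained("./diffusionLM-small")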


class DiffusionLLM(PreTrainedModel):
    """Main Diffusion-LLM model class.

    A thin PreTrainedModel wrapper around LLaDAModel that provides the
    standard Hugging Face configuration and serialization hooks plus a
    unified generation entry point.
    """

    config_class = DiffusionConfig
    base_model_prefix = "diffusionLM"

    def __init__(self, config: DiffusionConfig):
        super().__init__(config)
        self.model = LLaDAModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        timesteps=None,
        labels=None,
        return_dict=True,
    ):
        # `return_dict` is accepted for Hugging Face API compatibility but is
        # not forwarded; the wrapped LLaDAModel's return format is used as-is.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            timesteps=timesteps,
            labels=labels,
        )
        return outputs

    def generate(
        self,
        prompt=None,
        max_length=100,
        num_inference_steps=50,
        temperature=1.0,
        strategy="random",
        top_p=0.9,
        top_k=50,
        num_beams=5,
        return_scores=False,
        use_streaming=False,
        callback_fn=None,
    ):
        """Unified generation interface.

        Dispatches to the streaming generator when `use_streaming` is True and
        otherwise delegates to LLaDAModel.generate; the sampling options
        (temperature, strategy, top_p, top_k, num_beams) are forwarded to the
        chosen backend unchanged.
        """
        if use_streaming:
            return self.generate_stream(
                prompt=prompt,
                max_length=max_length,
                num_inference_steps=num_inference_steps,
                temperature=temperature,
                strategy=strategy,
                top_p=top_p,
                top_k=top_k,
                num_beams=num_beams,
                callback_fn=callback_fn,
            )
        return self.model.generate(
            prompt=prompt,
            max_length=max_length,
            num_inference_steps=num_inference_steps,
            temperature=temperature,
            strategy=strategy,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            return_scores=return_scores,
        )
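
    # Example call (sketch): routing generation through the streaming path. The
    # callback signature is defined by LLaDAModel.generate_stream, so the lambda
    # below is purely illustrative and may need to change to match it.
    #
    #   model.generate(
    #       prompt=prompt_ids,
    #       num_inference_steps=50,
    #       use_streaming=True,
    #       callback_fn=lambda *args, **kw: print("partial result", args, kw),
    #   )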

    def generate_stream(self, **kwargs):
        """Streaming generation wrapper."""
        return self.model.generate_stream(**kwargs)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Prepare inputs for generation compatibility."""
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask", None),
            "timesteps": kwargs.get("timesteps", None),
        }

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cache for beam search compatibility."""
        # No autoregressive key/value cache is maintained here, so there is
        # nothing to reorder; `past` is returned unchanged.
        return past
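

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API. It assumes that
    # LLaDAModel accepts the keyword arguments forwarded by DiffusionLLM.forward
    # above, that `timesteps` is a per-example integer tensor, and that passing
    # `labels` is valid; adjust to the actual LLaDAModel interface if needed.
    config = DiffusionConfig(
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        max_position_embeddings=128,
    )
    model = DiffusionLLM(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    timesteps = torch.randint(0, config.num_timesteps, (batch_size,))

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        timesteps=timesteps,
        labels=input_ids,
    )
    print(type(outputs))

    # Non-streaming generation, relying on LLaDAModel.generate's signature as
    # mirrored by the wrapper above.
    generated = model.generate(
        prompt=input_ids[:1],
        max_length=32,
        num_inference_steps=10,
        strategy="random",
    )
    print(type(generated))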