import warnings

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

from diffusionLM.model.diffusionLM import LLaDAModel


class DiffusionConfig(PretrainedConfig):
    """Configuration class for Diffusion-LLM model.

    Holds the transformer hyper-parameters (sizes, dropout, position limit,
    initializer range, layer-norm epsilon) together with the diffusion
    settings (``num_timesteps``, ``time_embed_dim``, ``mask_token_id``) that
    are handed to ``LLaDAModel`` by ``DiffusionLLM``.

    Note:
        With the defaults, ``mask_token_id`` and ``eos_token_id`` are the
        same id (50256). NOTE(review): together with ``vocab_size=50257``
        this looks like a GPT-2 style vocabulary whose last id is the
        end-of-text token — confirm that sharing one id between "masked"
        and "end of sequence" is intended.
    """

    model_type = "diffusionLM"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1024,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        mask_token_id: int = 50256,
        eos_token_id: int = 50256,
        num_timesteps: int = 100,
        time_embed_dim: int = 128,
        **kwargs,
    ):
        # Remaining keyword arguments (e.g. generic HF config options) are
        # handled by PretrainedConfig.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Transformer architecture.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Special token ids (assigned after super().__init__ so these values win).
        self.mask_token_id = mask_token_id
        self.eos_token_id = eos_token_id

        # Diffusion process.
        self.num_timesteps = num_timesteps
        self.time_embed_dim = time_embed_dim


class DiffusionLLM(PreTrainedModel):
    """Main Diffusion-LLM model class.

    A thin ``PreTrainedModel`` wrapper around ``LLaDAModel``: every forward
    and generation call is delegated to ``self.model``. The wrapper exists so
    the model gets Hugging Face config handling, ``save_pretrained`` /
    ``from_pretrained`` and related utilities.
    """

    config_class = DiffusionConfig
    # NOTE(review): HF uses ``base_model_prefix`` as the attribute name of the
    # wrapped base model (``PreTrainedModel.base_model``) and when matching
    # checkpoint key prefixes; the wrapped module is stored as ``self.model``,
    # not ``self.diffusionLM``. Left unchanged to avoid altering checkpoint
    # loading — confirm which name is intended.
    base_model_prefix = "diffusionLM"

    def __init__(self, config: DiffusionConfig):
        """Build the wrapped ``LLaDAModel`` from ``config`` and initialize it."""
        super().__init__(config)
        self.model = LLaDAModel(config)
        # ``post_init`` is the transformers entry point for the end of model
        # construction: it runs ``init_weights()`` and then the remaining
        # final setup (e.g. gradient-checkpointing backward compatibility).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        timesteps=None,
        labels=None,
        return_dict=True,
    ):
        """Run the wrapped ``LLaDAModel`` forward pass.

        Args:
            input_ids: Token ids passed straight to ``LLaDAModel``.
            attention_mask: Optional attention mask, passed through.
            timesteps: Optional diffusion timesteps, passed through.
            labels: Optional labels, passed through (presumably used by
                ``LLaDAModel`` to compute a loss — TODO confirm).
            return_dict: Accepted for Hugging Face API compatibility but
                currently ignored; the output format is whatever
                ``LLaDAModel`` returns.

        Returns:
            The unmodified output of ``self.model(...)``.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            timesteps=timesteps,
            labels=labels,
        )

    def generate(
        self,
        prompt=None,
        max_length=100,
        num_inference_steps=50,
        temperature=1.0,
        strategy='random',
        top_p=0.9,
        top_k=50,
        num_beams=5,
        return_scores=False,
        use_streaming=False,
        callback_fn=None,
    ):
        """Unified generation interface.

        Dispatches to ``self.generate_stream`` when ``use_streaming`` is true,
        otherwise to ``self.model.generate``. Both back-ends receive the same
        sampling arguments.

        Args:
            prompt: Prompt forwarded to the back-end generator.
            max_length: Maximum generation length.
            num_inference_steps: Number of denoising steps.
            temperature: Sampling temperature.
            strategy: Strategy name understood by ``LLaDAModel``
                (default ``'random'``).
            top_p: Nucleus-sampling threshold.
            top_k: Top-k sampling cutoff.
            num_beams: Beam count forwarded to the back-end.
            return_scores: Forwarded to ``self.model.generate`` only. The
                streaming back-end does not receive it, so a warning is
                emitted if it is requested together with ``use_streaming``.
            use_streaming: Select the streaming back-end.
            callback_fn: Forwarded to the streaming back-end only.

        Returns:
            Whatever the selected back-end returns.
        """
        # Arguments shared by both back-ends, built once so the two call
        # sites cannot drift apart.
        shared_kwargs = dict(
            prompt=prompt,
            max_length=max_length,
            num_inference_steps=num_inference_steps,
            temperature=temperature,
            strategy=strategy,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
        )

        if use_streaming:
            if return_scores:
                # The streaming path has no way to receive this flag; make the
                # drop visible instead of silently ignoring the request.
                warnings.warn(
                    "return_scores is ignored when use_streaming=True",
                    UserWarning,
                    stacklevel=2,
                )
            return self.generate_stream(callback_fn=callback_fn, **shared_kwargs)

        return self.model.generate(return_scores=return_scores, **shared_kwargs)

    def generate_stream(self, **kwargs):
        """Streaming generation wrapper; forwards all kwargs to ``self.model.generate_stream``."""
        return self.model.generate_stream(**kwargs)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Prepare inputs for generation compatibility.

        Returns a dict with ``input_ids`` plus the ``attention_mask`` and
        ``timesteps`` entries taken from ``kwargs`` (``None`` when absent).
        Any other keyword arguments are dropped.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask", None),
            "timesteps": kwargs.get("timesteps", None),
        }

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cache for beam search compatibility.

        No-op: ``past`` is returned unchanged and ``beam_idx`` is unused.
        Presumably this model keeps no per-beam cache — TODO confirm.
        """
        return past