diff --git a/torchtitan/experiments/deepseek_v3/checkpoint.py b/torchtitan/experiments/deepseek_v3/checkpoint.py new file mode 100644 index 00000000..535ac7fe --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/checkpoint.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +from typing import Dict, Optional, Set, Tuple + +import torch +from safetensors import safe_open + +from transformers.utils import cached_file + + +logger = logging.getLogger(__name__) + +_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" + + +def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]: + try: + with open(file_path, "r") as file: + data = json.load(file) + + if "weight_map" in data and isinstance(data["weight_map"], dict): + return data["weight_map"] + else: + logger.info("No 'weight_map' dictionary found in the JSON file.") + return None + except (json.JSONDecodeError, Exception) as e: + logger.info(f"An error occurred while reading the JSON file: {str(e)}") + return None + + +def get_hf_weight_map_and_path( + model_id: str, +) -> Tuple[Dict[str, str], str]: + """Get the weight map for a given HF model id and also the cache path for loading the weights""" + try: + index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME) + except Exception as e: + logger.error( + f"Model `{model_id}` not found in HF cache. " + f"You can download the model using `python download.py {model_id}" + ) + raise e + + weight_map = read_weights_from_json(index_file) + weight_path = os.path.dirname(index_file) + logger.info(f"Loading weights from: {weight_path}") + return weight_map, weight_path + + +def get_needed_files( + state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str] +) -> Set[str]: + needed_files = set() + for param in state_dict.keys(): + file = weight_map.get(param) + if file: + needed_files.add(file) + elif param.endswith("weight"): + raise ValueError( + f"Parameter {param} not found in weight map, please check..." + ) + logger.info(f"Needed files: {needed_files}") + return needed_files + + +def load_safetensor_file( + full_path: str, device: torch.device +) -> Dict[str, torch.Tensor]: + tensors = {} + with safe_open(full_path, framework="pt", device=device) as f: + for k in f.keys(): + tensors[k] = f.get_tensor(k) + logger.info(f"Loaded {len(tensors)} tensors from {full_path}") + return tensors + + +def load_safetensor_weights( + model: torch.nn.Module, + weight_map: Dict[str, str], + file_location: str, + device: torch.device, +): + """ + Load safetensor weights into a `nn.Module`. + + Args: + model (Module): The PyTorch module to load weights into. It may be a + model chunk or a full model. + weight_map (Dict[str, str]): Mapping of model parameters to file names. + file_location (str): Directory containing the weight files. + device (torch.device): The device to load tensors onto. 
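For reference, the `weight_map` returned by `read_weights_from_json` comes from the `model.safetensors.index.json` file that ships with sharded Hugging Face checkpoints; it maps each parameter name to the shard file holding it, and `get_needed_files` uses it so a model chunk only opens the shards it actually needs. A minimal sketch of that filtering (the map and parameter names below are made up for illustration):

```python
# Illustrative only: a toy weight map in the shape produced by
# "model.safetensors.index.json" (real DeepSeek checkpoints have thousands
# of entries spread over many shard files). All names below are made up.
toy_weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.26.mlp.experts.63.down_proj.weight": "model-00004-of-00004.safetensors",
    "lm_head.weight": "model-00004-of-00004.safetensors",
}

# A pipeline-parallel model chunk only owns a subset of the parameters, so
# only the shards covering that subset need to be opened:
chunk_params = {
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
}
needed = {toy_weight_map[p] for p in chunk_params if p in toy_weight_map}
print(needed)  # {'model-00001-of-00004.safetensors'}
```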
+ """ + model_state_dict = model.state_dict() + needed_files = get_needed_files(model_state_dict, weight_map) + updated_states: Set[str] = set() + + for file in needed_files: + full_path = os.path.join(file_location, file) + try: + checkpoint = load_safetensor_file(full_path, "cpu") + except FileNotFoundError: + logger.error(f"File not found: {full_path}") + except Exception as e: + logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}") + + matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys()) + for key in matched_keys: + # Check shape + if model_state_dict[key].shape != checkpoint[key].shape: + raise ValueError( + f"Shape mismatch for {key}: " + f"model needs {model_state_dict[key].shape}, but " + f"checkpoint has {checkpoint[key].shape}" + ) + model_state_dict[key] = checkpoint[key].to(device) + + updated_states.update(matched_keys) + + missing_keys = set(model_state_dict.keys()) - updated_states + if missing_keys: + raise RuntimeError( + f"Partially updated state dict. Missing parameters: {missing_keys}" + ) + + model.load_state_dict(model_state_dict, strict=False, assign=True) + logger.info(f"Successfully loaded {len(updated_states)} weights into model") + + +def load_weights_from_hf( + model: torch.nn.Module, + distribution: str, + device: torch.device, +): + """ + Load the weights from Hugging Face format (index file + multiple safetensor + files), and fill into `model`. Model config is needed b/c we permute + wq and wk weights based on attn heads. + """ + + weight_map, weight_path = get_hf_weight_map_and_path(distribution) + + load_safetensor_weights( + model, + weight_map, + weight_path, + device, + ) diff --git a/torchtitan/experiments/deepseek_v3/download.py b/torchtitan/experiments/deepseek_v3/download.py new file mode 100644 index 00000000..d6d1f1b6 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/download.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: +# python download.py {model_id} +# Example: +# python download.py deepseek-ai/DeepSeek-V2-Lite + +import sys + +from transformers import AutoModelForCausalLM + +model_id = sys.argv[1] + +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", +) diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py index 4cebfcdb..3ceb4bfa 100644 --- a/torchtitan/experiments/deepseek_v3/model.py +++ b/torchtitan/experiments/deepseek_v3/model.py @@ -27,7 +27,6 @@ # limitations under the License. """ PyTorch DeepSeek model.""" import math -from dataclasses import dataclass from typing import Optional, Tuple import numpy as np @@ -40,149 +39,9 @@ import torch.utils.checkpoint from attn_mask_utils import _prepare_4d_causal_attention_mask +from model_config import ModelArgs from symm_mem_recipes import on_device_all_to_all_v from torch import nn -from torch.nn import CrossEntropyLoss - - -@dataclass -class ModelArgs: - r""" - This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the DeepSeek-V3. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 129280): - Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`DeepseekV3Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 1407): - Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - n_shared_experts (`int`, *optional*, defaults to None): - Number of shared experts, None means dense model. - n_routed_experts (`int`, *optional*, defaults to None): - Number of routed experts, None means dense model. - routed_scaling_factor (`float`, *optional*, defaults to 1.0): - Scaling factor or routed experts. - topk_method (`str`, *optional*, defaults to `gready`): - Topk method used in routed gate. - n_group (`int`, *optional*, defaults to None): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to None): - Number of selected groups for each token(for each token, ensuring the selected experts is only within - `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to None): - Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. - first_k_dense_replace (`int`, *optional*, defaults to 0): - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). - \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to False): - Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to 'softmax'): - Method of computing expert weights. - aux_loss_alpha (`float`, *optional*, defaults to 0.001): - Auxiliary loss weight coefficient. - seq_aux = (`bool`, *optional*, defaults to True): - Whether to compute the auxiliary loss for each individual sample. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. 
- initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](/~https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- """ - - vocab_size: int = 129280 - hidden_size: int = 7168 - intermediate_size: int = 18432 - moe_intermediate_size: int = 2048 - num_hidden_layers: int = 61 - num_nextn_predict_layers: int = 1 - num_attention_heads: int = 128 - num_key_value_heads: int = 128 - n_shared_experts: int = 1 - n_routed_experts: int = 256 - ep_size: int = 1 - routed_scaling_factor: float = 2.5 - kv_lora_rank: int = 512 - q_lora_rank: int = 1536 - qk_rope_head_dim: int = 64 - v_head_dim: int = 128 - qk_nope_head_dim: int = 128 - topk_method: str = "noaux_tc" - n_group: int = 8 - topk_group: int = 4 - num_experts_per_tok: int = 8 - moe_layer_freq: int = 1 - first_k_dense_replace: int = 3 - norm_topk_prob: bool = True - scoring_func: str = "sigmoid" - aux_loss_alpha: float = 0.001 - seq_aux: bool = True - hidden_act: str = "silu" - max_position_embeddings: int = 4096 - initializer_range: float = 0.02 - rms_norm_eps: float = 1e-6 - rope_theta: float = 10000.0 - rope_scaling = None - attention_bias: bool = False - attention_dropout: float = 0.0 - pad_token_id = None - # Added for symmetric memory - max_seq_len: int = 4096 - # Added for pipeline parallel - num_stages: int = 1 - stage_idx: int = 0 # Get model parallel subgroup by name: @@ -192,15 +51,6 @@ def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup: return glob.get_group(dim_name) -# Get my pipeline parallel rank -def get_pp_rank() -> int: - try: - group = get_group("pp") - return group.rank() - except Exception: - return 0 - - class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -286,7 +136,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" @@ -533,6 +383,8 @@ def forward(self, hidden_states): ) if self.scoring_func == "sigmoid": scores = logits.sigmoid() + elif self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) else: raise NotImplementedError( f"insupportable scoring function for MoE gating: {self.scoring_func}" @@ -567,6 +419,10 @@ def forward(self, hidden_states): ) # [n, e] _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) topk_weight = scores.gather(1, topk_idx) + elif self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) else: raise NotImplementedError( f"insupportable TopK function for MoE gating: {self.topk_method}" @@ -599,7 +455,6 @@ def __init__(self, config): assert config.ep_size == self.ep_group.size() self.ep_size = config.ep_size self.ep_rank = self.ep_group.rank() - print(f"Creating EP rank {self.ep_rank} of {self.ep_size}") self.experts_per_rank = config.n_routed_experts // config.ep_size self.experts = nn.ModuleList( [ @@ -677,7 +532,7 @@ def forward(self, hidden_states): @torch.no_grad() def moe_infer(self, x, topk_ids, topk_weight): - if not self.has_symm_mem: + if self.ep_size > 1 and (not self.has_symm_mem): # Set up symmetric memory for the first time, then reuse it self.setup_symm_mem(x.dtype, x.device) @@ -753,7 +608,10 @@ def moe_infer(self, x, topk_ids, topk_weight): outputs.append(expert_out) start_idx = end_idx - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + # len(outputs) == 0 means no tokens routed to this EP rank. + # `sorted_tokens` would have shape [0, hidden_dim], we use it so that + # `outs` is an empty tensor with shape [0, hidden_dim] + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens if self.ep_size > 1: # Take necessary space from `token_gather_buf` symm mem new_x = self.token_gather_buf[: outs.shape[0]] @@ -1074,7 +932,7 @@ def forward( return hidden_states -DeepseekV3_INPUTS_DOCSTRING = r""" +Deepseek_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -1144,7 +1002,7 @@ def forward( """ -class DeepseekV3Model(torch.nn.Module): +class DeepseekModel(torch.nn.Module): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`] @@ -1224,38 +1082,32 @@ def forward( return hidden_states -class DeepseekV3ForCausalLM(torch.nn.Module): +class DeepseekForCausalLM(torch.nn.Module): def __init__(self, config): - super().__init__(config) - self.model = DeepseekV3Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + super().__init__() + self.model = DeepseekModel(config) + self.lm_head = ( + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if config.stage_idx == config.num_stages - 1 + else None + ) # Initialize weights and apply final processing # self.post_init() def forward( self, - input_ids: torch.LongTensor, + tokens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, ) -> Tuple: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
Indices should either be in `[0, transformers., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. - - Returns: - Example: ```python - >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + >>> from transformers import AutoTokenizer, DeepseekForCausalLM - >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" @@ -1267,29 +1119,15 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" hidden_states = self.model( - input_ids=input_ids, + tokens, attention_mask=attention_mask, position_ids=position_ids, ) - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - output = (logits,) - return (loss,) + output if loss is not None else output + logits = ( + self.lm_head(hidden_states) if self.lm_head is not None else hidden_states + ) + return logits def prepare_inputs_for_generation( self, @@ -1358,76 +1196,3 @@ def _reorder_cache(past_key_values, beam_idx): ), ) return reordered_past - - -# Start of testing part -# torchrun --standalone --nproc-per-node 4 model.py - -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.pipelining import PipelineStage, ScheduleGPipe - - -# Run full model -def run_full_model( - mesh: DeviceMesh, -): - rank = dist.get_rank() - device_count = torch.cuda.device_count() - device = torch.device("cuda", rank % device_count) - - pp_mesh = mesh["pp"] - ep_mesh = mesh["ep"] - pp_rank = pp_mesh.get_local_rank() - ep_rank = ep_mesh.get_local_rank() - pp_size = pp_mesh.size() - ep_size = ep_mesh.size() - - model_args = ModelArgs( - num_hidden_layers=3, - first_k_dense_replace=1, # activate MoE layers - ep_size=ep_size, # activate Expert Parallel - num_stages=pp_size, # activate Pipeline Parallel - stage_idx=pp_rank, # pipeline stage id - ) - print(model_args) - - # Instantiate model - with device, mesh: - model = DeepseekV3Model(model_args) - model.eval() - - # Example inputs - bs = 2 - microbatches = 2 - seqlen = 128 - x = torch.randint(model_args.vocab_size, (bs, seqlen), device=device) - - # Create pipeline stage - stage = PipelineStage( - model, - pp_rank, - pp_size, - device, - group=pp_mesh.get_group(), - ) - - # Create pipeline schedule - pp_schedule = ScheduleGPipe(stage, microbatches) - - # Run forward - if pp_rank == 0: - y = pp_schedule.step(x) - else: - y = pp_schedule.step() - - if pp_rank == pp_size - 1: - print(y.shape) - - -if __name__ == "__main__": - mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "ep")) - - with torch.no_grad(): - run_full_model(mesh) - - dist.destroy_process_group() diff --git a/torchtitan/experiments/deepseek_v3/model_config.py 
b/torchtitan/experiments/deepseek_v3/model_config.py new file mode 100644 index 00000000..fd7a340b --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/model_config.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class ModelArgs: + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within + `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. 
+ num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](/~https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
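The `scoring_func` and `topk_method` options documented above are what the DeepSeek-V2-Lite entry below sets to `"softmax"` and `"greedy"`, matching the new gating branches added to `model.py`. A minimal sketch of that gating path, with made-up shapes:

```python
import torch

# Minimal sketch of the gating path this PR enables for DeepSeek-V2-Lite
# (scoring_func="softmax", topk_method="greedy", norm_topk_prob=False).
# Shapes and values are made up for illustration.
n_tokens, n_experts, top_k = 4, 8, 2

logits = torch.randn(n_tokens, n_experts)             # router logits, [n, e]
scores = logits.softmax(dim=-1, dtype=torch.float32)  # expert scores, [n, e]

# "greedy": take the top-k expert scores per token, no expert grouping involved.
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# With norm_topk_prob=False the selected weights are used as-is (scaled by
# routed_scaling_factor in the full gate); with norm_topk_prob=True they
# would be renormalized to sum to 1 per token.
print(topk_idx.shape, topk_weight.shape)               # [4, 2], [4, 2]
```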
+ """ + + vocab_size: int = 129280 + hidden_size: int = 7168 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 61 + num_nextn_predict_layers: int = 1 + num_attention_heads: int = 128 + num_key_value_heads: int = 128 + n_shared_experts: int = 1 + n_routed_experts: int = 256 + ep_size: int = 1 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + n_group: int = 8 + topk_group: int = 4 + num_experts_per_tok: int = 8 + moe_layer_freq: int = 1 + first_k_dense_replace: int = 3 + norm_topk_prob: bool = True + scoring_func: str = "sigmoid" + aux_loss_alpha: float = 0.001 + seq_aux: bool = True + hidden_act: str = "silu" + max_position_embeddings: int = 4096 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling = None + attention_bias: bool = False + attention_dropout: float = 0.0 + pad_token_id = None + # Added for symmetric memory + max_seq_len: int = 4096 + # Added for pipeline parallel + num_stages: int = 1 + stage_idx: int = 0 + + +# This is the configuration for deepseek-ai/DeepSeek-V2-Lite. +deepseek_v2_lite_config = ModelArgs( + vocab_size=102400, + hidden_size=2048, + intermediate_size=10944, + moe_intermediate_size=1408, + num_hidden_layers=27, + num_attention_heads=16, + num_key_value_heads=16, + n_shared_experts=2, + n_routed_experts=64, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="greedy", + n_group=None, + topk_group=None, + num_experts_per_tok=6, + first_k_dense_replace=1, + norm_topk_prob=False, + scoring_func="softmax", + max_position_embeddings=2048, +) + + +# Model configuration registry +# Key is the model distribution ID on HuggingFace Hub +deepseek_config_registry = { + "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config, + "deepseek-ai/deepseek-v3": ModelArgs(), +} diff --git a/torchtitan/experiments/deepseek_v3/run.py b/torchtitan/experiments/deepseek_v3/run.py new file mode 100644 index 00000000..637b4012 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/run.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# torchrun --standalone --nproc-per-node 4 run.py +import torch +import torch.distributed as dist +from checkpoint import load_weights_from_hf +from model import DeepseekForCausalLM +from model_config import deepseek_config_registry + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage, ScheduleGPipe + + +# Use DeepSeek-V2-Lite as a proxy +model_id = "deepseek-ai/DeepSeek-V2-Lite" + + +# Run full model +def run_full_model( + mesh: DeviceMesh, +): + rank = dist.get_rank() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + pp_mesh = mesh["pp"] + ep_mesh = mesh["ep"] + pp_rank = pp_mesh.get_local_rank() + ep_rank = ep_mesh.get_local_rank() + pp_size = pp_mesh.size() + ep_size = ep_mesh.size() + + # Get model configs + model_args = deepseek_config_registry[model_id] + + # Apply model parallelism + model_args.ep_size = ep_size + model_args.num_stages = pp_size + model_args.stage_idx = pp_rank + print(model_args) + + # Instantiate model + with device, mesh: + model = DeepseekForCausalLM(model_args) + model.eval() + + # Load weights + load_weights_from_hf(model, model_id, device) + + # Example inputs + bs = 2 + microbatches = 2 + seqlen = 128 + x = torch.randint(model_args.vocab_size, (bs, seqlen), device=device) + + # Create pipeline stage + stage = PipelineStage( + model, + pp_rank, + pp_size, + device, + group=pp_mesh.get_group(), + ) + + # Create pipeline schedule + pp_schedule = ScheduleGPipe(stage, microbatches) + + # Run forward + if pp_rank == 0: + y = pp_schedule.step(x) + else: + y = pp_schedule.step() + + if pp_rank == pp_size - 1: + print(y.shape) + + +if __name__ == "__main__": + mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "ep")) + + with torch.no_grad(): + run_full_model(mesh) + + dist.destroy_process_group()
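For orientation, this is how the four ranks land on the 2x2 ("pp", "ep") mesh used above, assuming the default row-major rank layout of `init_device_mesh` (the first mesh dimension varies slowest):

```python
# Illustrative only: mapping of global ranks onto the ("pp", "ep") 2x2 mesh
# created in run.py, assuming the default rank layout of init_device_mesh.
for rank in range(4):
    pp_rank, ep_rank = divmod(rank, 2)
    print(f"global rank {rank} -> pp_rank {pp_rank}, ep_rank {ep_rank}")
```

Each pipeline stage therefore spans two ranks that form its expert-parallel group; on the last stage (`pp_rank == 1`), the printed `y.shape` is expected to be the full-batch logits, roughly `(bs, seqlen, vocab_size)`.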