Bring LLaMa 3.1 405B to TorchTitan family #481

Merged
12 commits merged on Aug 1, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -74,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
```bash
# Get your HF token from https://huggingface.co/settings/tokens

-# llama3 tokenizer.model
+# llama3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
4 changes: 2 additions & 2 deletions torchtitan/datasets/download_tokenizer.py
@@ -20,8 +20,8 @@ def hf_download(

     try:
         hf_hub_download(
-            repo_id,
-            tokenizer_path,
+            repo_id=repo_id,
+            filename=tokenizer_path,
             local_dir=local_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
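For reference, the keyword-argument form matches the signature of `hf_hub_download` in `huggingface_hub`, whose second positional parameter is `filename`; spelling out `repo_id=` and `filename=` keeps the call unambiguous when the local variable is named `tokenizer_path`. A minimal standalone sketch of the same call (the `download_tokenizer` helper and the example values in the comments are hypothetical, not taken from this PR):

```python
from typing import Optional

from huggingface_hub import hf_hub_download


# Hypothetical helper mirroring the call in download_tokenizer.py.
def download_tokenizer(
    repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    hf_hub_download(
        repo_id=repo_id,               # e.g. "meta-llama/Meta-Llama-3-8B"
        filename=tokenizer_path,       # path of tokenizer.model inside the repo
        local_dir=local_dir,           # where to place the downloaded file
        local_dir_use_symlinks=False,  # copy the file instead of symlinking into the HF cache
        token=hf_token,                # HF access token for gated repos
    )
```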
9 changes: 9 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -48,4 +48,13 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
+    "405B": ModelArgs(
+        dim=16384,
+        n_layers=126,
+        n_heads=128,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
 }
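The new entry mirrors the published Llama 3.1 405B architecture: 16384 model dimension, 126 layers, 128 attention heads with 8 KV heads (GQA). The FFN hidden size is not stored directly; assuming torchtitan follows the reference Llama feed-forward sizing rule, it is derived from `ffn_dim_multiplier` and `multiple_of` as in this sketch:

```python
# Sketch of the standard Llama FFN sizing rule (assumed, not copied from torchtitan):
# start from 4*dim, take 2/3 of it, scale by ffn_dim_multiplier, round up to multiple_of.
def ffn_hidden_dim(dim: int, ffn_dim_multiplier: float, multiple_of: int) -> int:
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(16384, 1.2, 4096))  # 53248 for the 405B entry above
```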
53 changes: 53 additions & 0 deletions train_configs/llama3_405b.toml
@@ -0,0 +1,53 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 128 H100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 405B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 0.8e-4

[training]
batch_size = 2
seq_len = 8192
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 3000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
enable_float8_linear = false
compile = false
Contributor
wondering if there were any blockers on enabling compile?

Contributor Author
I don't know. Any reason why we didn't enable it for 70B? @tianyu-l For 405B, the ideal toml will have 3D parallelism for sure. Whether to enable compile is something we can discuss.

Contributor
Previously there were issues with compile. Now there's no blocker, maybe just extra compile time?
BTW, the compiler could potentially help reduce memory so that 405B 2D can be trained on 128 A100 GPUs.

Contributor Author
Let me try with compile; if it works, I will update the toml file in a different PR.

dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']
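For a quick sanity check on the 128-GPU preset: with tensor_parallel_degree = 8, pipeline_parallel_degree = 1, and data_parallel_degree = -1 (taken here to mean "use whatever GPUs remain", which is an assumption, not stated in the config), the data-parallel degree works out to 16, and with batch_size = 2 and seq_len = 8192 each optimizer step covers 262,144 tokens. A back-of-the-envelope sketch of that arithmetic:

```python
# Rough decomposition of the 128-GPU preset above; assumes data_parallel_degree = -1
# means "fill the remaining world size" after TP and PP are applied.
world_size = 128
tensor_parallel_degree = 8
pipeline_parallel_degree = 1
data_parallel_degree = world_size // (tensor_parallel_degree * pipeline_parallel_degree)

batch_size = 2  # per data-parallel rank, from [training]
seq_len = 8192
tokens_per_step = data_parallel_degree * batch_size * seq_len

print(data_parallel_degree)  # 16
print(tokens_per_step)       # 262144
```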