[float8] add float8 training benchmarking scripts (#1802)
* add float8 training benchmarking scripts

* move to benchmarks/float8/training
danielvegamyhre authored Mar 1, 2025
1 parent 8f93751 commit 7963f9c
Showing 3 changed files with 122 additions and 0 deletions.
18 changes: 18 additions & 0 deletions benchmarks/float8/README.md
@@ -0,0 +1,18 @@
# Float8 training benchmarking

The `float8_training_benchmark.sh` script in this directory can be used to launch a Llama3 8b training run with [torchtitan](/~https://github.com/pytorch/torchtitan), then parse the logs to calculate the median tokens/sec and peak memory usage for you.

## Usage

Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE=rowwise ./float8_training_benchmark.sh`

Training parameters can be configured via environment variables.

- Required:
- `TORCHTITAN_ROOT`
- Optional:
- `FLOAT8_RECIPE`: `rowwise` or `tensorwise`. If unset, bf16 mixed precision training is used instead of float8.
- `BATCH_SIZE`: defaults to 1.
- `STEPS`: defaults to 100.

**NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.
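The reported metrics come from torchtitan's per-step log lines. A minimal sketch of the extraction logic, using the same regex as `parse_torchtitan_logs.py`; the sample log lines below are hypothetical, and the real torchtitan log format may vary between versions:

```python
import re
import statistics

# Regex taken from parse_torchtitan_logs.py: captures step number,
# memory in GiB, and tokens/sec (which may contain thousands separators).
log_pattern = re.compile(r"step: (\d+).*?memory: ([\d.]+)GiB.*?tps: ([\d,]+)")

# Hypothetical sample of what torchtitan step logs might look like.
sample = (
    "step: 1  loss: 7.51  memory: 41.2GiB  tps: 1,234\n"
    "step: 2  loss: 6.98  memory: 42.0GiB  tps: 5,678\n"
    "step: 3  loss: 6.80  memory: 42.1GiB  tps: 5,702\n"
)

matches = log_pattern.findall(sample)

# Exclude step 1, which includes initialization overhead, as the parser does.
tps = [float(t.replace(",", "")) for s, _, t in matches if int(s) != 1]

print(statistics.median(tps))                 # median tokens/sec: 5690.0
print(max(float(m) for _, m, _ in matches))  # peak memory: 42.1 GiB
```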
47 changes: 47 additions & 0 deletions benchmarks/float8/training/float8_training_benchmark.sh
@@ -0,0 +1,47 @@
#!/bin/bash
# This script can be used to launch a torchtitan float8 training run
# with the given parameters.

# script arguments
BATCH_SIZE=${BATCH_SIZE:-1}
STEPS=${STEPS:-100}

# temporary log file which is deleted after performance data is parsed out and metrics are calculated.
LOG_FILE="/tmp/float8_training_log.txt"

# validate user has specified torchtitan root directory
if [ -z "${TORCHTITAN_ROOT}" ]; then
echo "Error: TORCHTITAN_ROOT environment variable is not set. Please set it before running this script."
echo "Usage: TORCHTITAN_ROOT=<directory> ./float8_training_benchmark.sh"
echo "Optional parameters configurable via environment variables:"
echo " * FLOAT8_RECIPE: \"rowwise\" or \"tensorwise\". If set, use float8 training with the specified recipe; otherwise, use bf16 mixed precision training."
echo " * BATCH_SIZE: defaults to 1."
echo " * STEPS: defaults to 100."
exit 1
fi

# if a float8 recipe is specified, build the float8 conversion args;
# otherwise FLOAT8_ARGS stays empty and bf16 mixed precision is used
if [ -n "${FLOAT8_RECIPE}" ]; then
FLOAT8_ARGS="--model.converters=float8 --float8.recipe_name=${FLOAT8_RECIPE}"
fi


# remember current directory to return to it later
original_dir=$(pwd)

# navigate to torchtitan root dir
cd "${TORCHTITAN_ROOT}" || exit 1

echo "float8 args: ${FLOAT8_ARGS}"

# run the command with the specified arguments
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} 2>&1 | tee ${LOG_FILE}

# return to original working directory
cd "$original_dir"

# parse logs to calculate top line metrics (the parser lives next to this script)
python "$(dirname "$0")/parse_torchtitan_logs.py" --log-file "${LOG_FILE}"

# clean up logs
rm ${LOG_FILE}
57 changes: 57 additions & 0 deletions benchmarks/float8/training/parse_torchtitan_logs.py
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Script which can be used to parse the log file generated by torchtitan
and calculate training performance metrics (median tokens/second and peak memory usage).
Usage:
python parse_torchtitan_logs.py --log-file <log_file_path>
"""

import os
import re
import statistics
from argparse import ArgumentParser, Namespace


def main(args: Namespace):
print("\n=====================================================")
print(" Calculating training performance metrics")
print("=====================================================")

log_pattern = re.compile(r"step: (\d+).*?memory: ([\d.]+)GiB.*?tps: ([\d,]+)")

assert os.path.exists(args.log_file), f"{args.log_file} does not exist"

with open(args.log_file, "r") as f:
log_data = f.read()

matches = log_pattern.findall(log_data)

tokens_per_second = []
max_memory_usage = 0.0
for match in matches:
step = int(match[0])
memory_usage = float(match[1])
tps = float(match[2].replace(",", ""))

# update peak memory usage
max_memory_usage = max(max_memory_usage, memory_usage)

# collect tokens per second, excluding step 1 which has initialization overhead
if step != 1:
tokens_per_second.append(tps)

# calculate median tokens per second
median_tps = statistics.median(tokens_per_second) if tokens_per_second else 0

print(f"Median Tokens/Second (excluding step 1): {median_tps}")
print(f"Max Memory Usage: {max_memory_usage} GiB")


if __name__ == "__main__":
argparser = ArgumentParser()
argparser.add_argument(
"--log-file", type=str, required=True, help="torchtitan log file"
)
args = argparser.parse_args()
main(args)
