Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update JetStream instructions #132

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 109 additions & 39 deletions docs/online-inference-with-maxtext-engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Follow the steps in [Manage TPU resources | Google Cloud](https://cloud.google.c
## Step 1: Download JetStream and the MaxText github repository

```bash
git clone -b jetstream-v0.2.2 /~https://github.com/google/maxtext.git
git clone -b v0.2.2 /~https://github.com/google/JetStream.git
git clone /~https://github.com/google/maxtext.git
git clone /~https://github.com/google/JetStream.git
```

## Step 2: Setup MaxText
Expand All @@ -45,16 +45,16 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect
### Use a Gemma model checkpoint

* You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b).
* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
* After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
* `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
* Please refer to the [conversion script](/~https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
* Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint.

```bash
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For gemma-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}
```

Note: For more information about the Gemma model and checkpoints, see [About Gemma](/~https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md).
Expand All @@ -63,25 +63,25 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem
### Use a Llama2 model checkpoint

* You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/).
* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
* After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
* `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
* Please refer to the [conversion script](/~https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
* Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint.

```bash
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For llama2-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For llama2-13b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}
```

Note: For more information about the Llama2 model and checkpoints, see [About Llama2](/~https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md).


## Step4: Run the JetStream MaxText server
## Step 4: Run the JetStream MaxText server


### Create model config environment variables for server flags
Expand All @@ -104,8 +104,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Expand All @@ -122,17 +122,15 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
```

#### Create Llama2-13b environment variables for server flags



* Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passing into the JetStream MaxText server

```bash
Expand All @@ -142,8 +140,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
Expand Down Expand Up @@ -187,7 +185,8 @@ python MaxText/maxengine_server.py \
Note: these flags are from [MaxText config](/~https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml)


## Step 5: Send test request to JetStream MaxText server
## Step 5: Send a test request to JetStream MaxText server
In a new tab in your terminal, run the following command

```bash
cd ~
Expand All @@ -207,32 +206,100 @@ Response: to be a fan

## Step 6: Run benchmarks with JetStream MaxText server

Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following:
Note: The JetStream MaxText Server commands from Step 4 are not running with any quantization optimizations. To get the best benchmark results, we need to enable quantization for weights and KV cache. To do this, first generate AQT trained or fine-tuned checkpoints. Then, add the quantization flags and restart the server.

### Generating a quantized checkpoint

First, define the path to which the quantized checkpoint
```bash
# Enable int8 quantization for both weights and KV cache
export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-7b-chat
```

There are several different quantization configurations to choose from:

#### int8 DRQ quantized checkpoint
```bash
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

#### Weights-only int8 quantized checkpoint
```bash
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

#### Mixed precision weight-only quantized checkpoint
First, update the mixed precision config file (`MaxText/configs/quantization/mp_scale.json`) in MaxText repo to the mixed-precision-config defined below.
```
{
".*/query": {"bits": 4, "scale": 0.8},
".*/key": {"bits": 4, "scale": 0.9},
".*/value": {"bits": 8},
".*/out": {"bits": 4},
".*/wi_0": {"bits": 4},
".*/wo": {"bits": 8}
}
```
Then run the following command:
```bash
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp
quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

### Restart the server with quantization flags

#### Set flags

Setting base quantization flags
```bash
# To load an int8 DRQcheckpoint
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True

# To load an int8 weight-only checkpoint
export QUANTIZATION=int8w
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True

# To load a Mixed-Precision quantized checkpoint
# If using Mixed-Precision mode, make sure to update the mixed precision config file to the same file as used for quantizing the checkpoint (MaxText/configs/quantization/mp_scale.json)
export QUANTIZATION=intmp
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True
export QUANT_CFG_PATH=configs/quantization/mp_scale.json
```

The KV-cache is quantized to int8 by using the following config params
```bash
export QUANTIZE_KVCACHE=True
```
If you don't want to quantize the KV-cache, set
```bash
export QUANTIZE_KVCACHE=False
```


#### Restart server
```bash
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
```

### Benchmarking Gemma-7b
Expand Down Expand Up @@ -261,11 +328,12 @@ python JetStream/benchmarks/benchmark_serving.py \
--request-rate 5 \
--warmup-mode sampled
```
For details, please see /~https://github.com/google/JetStream/blob/main/benchmarks/README.md

### Benchmarking Llama2-\*b
### Benchmarking Llama2

```bash
# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2).
# The command is the same as that for the Gemma-7b, except for the tokenizer. Since we need to use a tokenizer that matches the model, it should now be tokenizer.llama2.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
Expand All @@ -276,17 +344,19 @@ python JetStream/benchmarks/benchmark_serving.py \
--request-rate 5 \
--warmup-mode sampled
```
For details, please see /~https://github.com/google/JetStream/blob/main/benchmarks/README.md

## Clean Up

```bash
# Clean up gcs buckets.
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up python virtual environment
rm -rf .env
```
12 changes: 4 additions & 8 deletions jetstream/tools/maxtext/model_ckpt_conversion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,21 @@ export MODEL=$1
export MODEL_VARIATION=$2
export MODEL_NAME=${MODEL}-${MODEL_VARIATION}

# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET
# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Point these variables to a GCS bucket that you created.
# An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
export CHKPT_BUCKET=$3
export MODEL_BUCKET=gs://${USER}-maxtext
export MODEL_BUCKET=$4

# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs

# Point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://${USER}-maxtext-dataset
# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint.
export BASE_OUTPUT_DIRECTORY=$5

export BUCKET_LOCATION=US

# Create three GCS buckets for the demo.
gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true
gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true
gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true

# Convert model checkpoints to MaxText compatible checkpoints.
if [ "$MODEL" == "gemma" ]; then
Expand Down
Loading