Skip to content

Commit

Permalink
Add LLAMA (#1446)
Browse files Browse the repository at this point in the history
Summary:
Fixes #1443

Notable things I had to do - ezyang hopefully these are OK
* Since LLAMA requires special permission to download weights and checkpoints and the tokenizer, I went ahead with random checkpoints and random tokenizer - not sure CI qualifies as a valid research endeavour
* I removed the dependency on fairscale, so I had to make a few adjustments, such as turning ParallelLinear into Linear and ParallelEmbedding into Embedding; things mostly seem to work fine. An added bonus is that you can run the example on a single machine
* The original code ran inference under torch inference mode; I removed it since it has a weird interaction with torch.compile
* The open source LLAMA repo is inference only so there is no training support in this script

Some other things I can improve in another PR
* Better configuration including sequence length and batching
* Reenabling distributed support with FAIRSCALE

I can run the code now

```
(bench) ubuntu@ip-172-31-39-186:~/benchmark$ python run.py llama -d cuda
Running eval method from llama on cuda in eager mode with input batch size 32.
GPU Time:             10.006 milliseconds
CPU Total Wall Time:  10.045 milliseconds
```

Pull Request resolved: #1446

Reviewed By: msaroufim

Differential Revision: D43960031

Pulled By: xuzhao9

fbshipit-source-id: 05d58ff0c92080542a16433ab3eb550322525152
  • Loading branch information
msaroufim authored and facebook-github-bot committed Mar 11, 2023
1 parent d8e5325 commit c78f1f3
Show file tree
Hide file tree
Showing 9 changed files with 416 additions and 0 deletions.
10 changes: 10 additions & 0 deletions torchbenchmark/models/ADDING_MODELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ Some of the APIs are optional, and you can raise NotImplemented if a particular

Take care to set the random seed like [here](https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/Background_Matting/__init__.py#L20), to ensure your model runs the same way each time it's benchmarked.


#### A minimal new model addition
A bare minimum example you can follow is https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/phlippe_resnet

The functions you specifically need to implement are
1. `__init__()` which is responsible for initializing your `nn.Module`
2. `get_module()` which is responsible for returning the initialized `nn.Module` and an example input
3. `train()` which is a training loop; you can raise a `NotImplementedError()` if your example is inference only
4. `eval()` which showcases a simple inference

### Preparing install.py and dependencies
Simply put, install.py should be a one stop shop to install all the dependencies
for your model, __except torch, torchvision, torchtext__ which should be assumed to
Expand Down
39 changes: 39 additions & 0 deletions torchbenchmark/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.



from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch
from .model import ModelArgs, Transformer
import torch

class Model(BenchmarkModel):
    """TorchBench wrapper around the inference-only LLaMA ``Transformer``.

    The network is built from ``ModelArgs(vocab_size=32)`` with freshly
    initialised (random) weights; no checkpoint or tokenizer is loaded.
    Only ``eval`` is supported; ``train`` raises ``NotImplementedError``.
    """

    task = NLP.LANGUAGE_MODELING

    # NOTE: the ``extra_args=[]`` default is kept as-is so the constructor
    # signature stays unchanged; the list is only forwarded to the base class
    # and never mutated here.
    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.model_args = ModelArgs(vocab_size=32)
        # Process-wide setting: tensors created without an explicit device
        # while building the Transformer are placed on the benchmark device.
        torch.set_default_device(self.device)
        self.model = Transformer(self.model_args).to(self.device)
        # Positional arguments for Transformer.forward: a (2, 2) batch of
        # token ids and the start position 1.
        self.example_inputs = (torch.tensor([[1, 1], [1, 1]], dtype=torch.int).to(self.device), 1)

    def get_module(self):
        """Return the wrapped module together with its example inputs."""
        return self.model, self.example_inputs

    def train(self):
        """Training is unsupported for this model.

        Raises:
            NotImplementedError: always. The exception is raised (not
                returned) so callers that probe for training support by
                catching it actually see it.
        """
        error_msg = """
            As of March 6, 2023
            The weights for this model are not publicly available and require a valid research reason to use
            The publicly available github repo is inference only
            https://github.com/facebookresearch/llama
        """
        raise NotImplementedError(error_msg)

    def eval(self):
        """Run one forward pass without gradients and return ``(output,)``."""
        self.model.eval()
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
77 changes: 77 additions & 0 deletions torchbenchmark/models/llama/generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from .tokenizer import Tokenizer
from .model import Transformer


class LLaMA:
    """Text generation driver pairing a ``Transformer`` with a ``Tokenizer``."""

    def __init__(self, model: "Transformer", tokenizer: "Tokenizer"):
        self.model = model
        self.tokenizer = tokenizer

    def _model_device(self) -> torch.device:
        """Return the device that holds the model's parameters.

        Falls back to CUDA when available (the previously hard-coded choice),
        otherwise CPU, for a model that exposes no parameters.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """Complete every prompt and return the decoded strings.

        Args:
            prompts: Prompt strings; at most ``model.params.max_batch_size``.
            max_gen_len: Maximum number of tokens to generate per prompt.
            temperature: Softmax temperature; a value ``<= 0`` selects greedy
                (argmax) decoding instead of sampling.
            top_p: Cumulative-probability threshold passed to ``sample_top_p``.

        Returns:
            One decoded string per prompt, cut at the first EOS token if any.

        Raises:
            AssertionError: if there are more prompts than ``max_batch_size``.
        """
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # Allocate the token buffer on the model's device instead of assuming
        # CUDA, so generation also works for a model that lives on the CPU.
        device = self._model_device()
        tokens = torch.full(
            (bsz, total_len), self.tokenizer.pad_id, dtype=torch.long, device=device
        )
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
        # True wherever a prompt token is present (i.e. not padding).
        input_text_mask = tokens != self.tokenizer.pad_id

        # Generation starts right after the shortest prompt; longer prompts
        # keep their own tokens at those positions (see torch.where below).
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # Only the tokens not yet seen are fed; prev_pos is passed as the
            # model's start position.
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                # No EOS token was produced; keep the full sequence.
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    """Draw one token index per row using top-p (nucleus) sampling.

    Tokens are ranked by probability along the last dimension; a token is
    dropped when the probability mass of the tokens ranked before it already
    exceeds ``p``. The survivors are renormalised and one is sampled.

    Returns a tensor of token indices with a trailing dimension of size 1.
    """
    ranked_probs, ranked_idx = torch.sort(probs, dim=-1, descending=True)
    running_total = torch.cumsum(ranked_probs, dim=-1)
    # Mass accumulated strictly before each ranked token.
    mass_before = running_total - ranked_probs
    kept = ranked_probs.masked_fill(mass_before > p, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    picked_rank = torch.multinomial(kept, num_samples=1)
    # Translate the sampled rank back into the original token index.
    return torch.gather(ranked_idx, -1, picked_rank)
8 changes: 8 additions & 0 deletions torchbenchmark/models/llama/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import subprocess
import sys

def pip_install_requirements():
    """Install the packages listed in ``requirements.txt`` with pip.

    pip is run through the current interpreter (``sys.executable -m pip``);
    ``requirements.txt`` is resolved relative to the working directory.
    ``subprocess.check_call`` raises ``CalledProcessError`` if pip fails.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']
    subprocess.check_call(pip_command)


if __name__ == '__main__':
    pip_install_requirements()
8 changes: 8 additions & 0 deletions torchbenchmark/models/llama/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 1024
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
232 changes: 232 additions & 0 deletions torchbenchmark/models/llama/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    """Hyper-parameters that size the ``Transformer`` and its caches."""

    # Embedding / residual width.
    dim: int = 512
    # Number of TransformerBlock layers.
    n_layers: int = 8
    # Attention heads per layer; each head has size dim // n_heads.
    n_heads: int = 8
    # Vocabulary size; the -1 default must be overridden by the caller.
    vocab_size: int = -1
    # make SwiGLU hidden layer size multiple of large power of 2
    multiple_of: int = 256
    # Epsilon handed to every RMSNorm.
    norm_eps: float = 1e-5

    # Dimensions of the per-layer key/value caches.
    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then scale.
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit complex numbers used for rotary embeddings.

    Returns a ``complex64`` tensor of shape ``(end, dim // 2)`` whose entry
    ``[t, i]`` is ``exp(1j * t * theta ** (-2 * i / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Magnitude 1, phase = angle.
    return torch.polar(torch.ones_like(angles), angles)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last dimension of ``xq``/``xk`` are treated
    as (real, imaginary) parts, multiplied by ``freqs_cis`` and flattened
    back. The outputs are cast back to the dtypes of ``xq`` and ``xk``.
    """

    def _pairs_as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d // 2, 2) -> complex (..., d // 2), in float32.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _pairs_as_complex(xq)
    k_complex = _pairs_as_complex(xk)
    # The same broadcast view (derived from the query shape) is used for both.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and a key/value cache.

    Single-process variant: all ``args.n_heads`` heads are local (a model
    parallel world size of 1 is assumed).
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()

        # Basically we just assume world size of 1.
        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # Key/value caches are plain tensors (not registered buffers). They
        # are allocated on the default device rather than forced onto CUDA so
        # the module can be built on CPU-only hosts; forward() moves them to
        # the device and dtype of the queries before use.
        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_local_heads,
            self.head_dim,
        )
        self.cache_k = torch.zeros(cache_shape)
        self.cache_v = torch.zeros(cache_shape)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Attend over the cached positions plus the ``seqlen`` new ones.

        Args:
            x: Input of shape ``(bsz, seqlen, dim)``.
            start_pos: Cache index at which the new keys/values are written.
            freqs_cis: Rotary factors for the new positions.
            mask: Accepted for interface compatibility; currently NOT applied
                (see the note in the body).

        Returns:
            Tensor of shape ``(bsz, seqlen, dim)``.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Match the caches to the queries' device and dtype.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Store the new keys/values, then read back everything up to the
        # current position.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # (bs, n_local_heads, seqlen, start_pos + seqlen)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # TODO: the mask is intentionally not applied. Adding it raised
        # "RuntimeError: The size of tensor a (3) must match the size of tensor b (2)
        # at non-singleton dimension 3" because the last dimension of `scores` is
        # start_pos + seqlen. Re-enable only with a mask of that width.
        # if mask is not None:
        #     scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)


class FeedForward(nn.Module):
    """Gated (SiLU) feed-forward network: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        # Take two thirds of the requested width, then round up to the
        # nearest multiple of `multiple_of`.
        two_thirds = int(2 * hidden_dim / 3)
        n_chunks = (two_thirds + multiple_of - 1) // multiple_of
        hidden_dim = multiple_of * n_chunks

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward, each with a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        # The feed-forward layer is asked for 4x the model width.
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        # Attention sub-layer on the normalised input, added back to x.
        attn_out = self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        h = x + attn_out
        # Feed-forward sub-layer on the normalised hidden state, added back to h.
        ffn_out = self.feed_forward.forward(self.ffn_norm(h))
        return h + ffn_out


class Transformer(nn.Module):
    """Decoder-only transformer that returns logits for the last position only.

    The embedding table and the output projection both have
    ``params.vocab_size + 1`` entries.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(
            params.vocab_size + 1, params.dim,
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim, params.vocab_size + 1, bias=False
        )

        # Rotary factors are precomputed for twice the maximum sequence
        # length; kept as a plain tensor attribute, not a registered buffer.
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Compute next-token logits for a batch of token ids.

        Args:
            tokens: Integer tensor of shape ``(bsz, seqlen)``.
            start_pos: Position of ``tokens[:, 0]`` in the full sequence; also
                the key/value cache offset used by the attention layers.

        Returns:
            ``float32`` logits of shape ``(bsz, vocab_size + 1)`` for the last
            position of ``tokens``.
        """
        _, seqlen = tokens.shape

        h = self.tok_embeddings(tokens)

        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None

        if seqlen > 1:
            # Causal mask over the cached positions plus the new ones. Its
            # width is start_pos + seqlen so it lines up with the attention
            # scores, which span every key up to the current position. Query
            # i sits at absolute position start_pos + i, so keys j with
            # j - i >= start_pos + 1 are masked with -inf. For start_pos == 0
            # this is the usual square causal mask.
            # (The attention layers in this file currently do not add the
            # mask to the scores.)
            mask = torch.full(
                (1, 1, seqlen, start_pos + seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
1 change: 1 addition & 0 deletions torchbenchmark/models/llama/origin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
origin /~https://github.com/facebookresearch/llama
1 change: 1 addition & 0 deletions torchbenchmark/models/llama/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sentencepiece
Loading

0 comments on commit c78f1f3

Please sign in to comment.