Summary: Fixes #1443

Notable things I had to do - ezyang, hopefully these are OK:

* Since LLaMA requires special permission to download the weights, checkpoints, and tokenizer, I went ahead with random checkpoints and a random tokenizer - I'm not sure CI qualifies as a valid research endeavour.
* I removed the dependency on fairscale, so I had to make a few adjustments like turning ParallelLinear into Linear and ParallelEmbedding into Embedding, and things mostly seem to work fine (a minimal sketch of this substitution follows below). An added bonus is that you can run the example on a single machine.
* The code ran inference under torch.inference_mode(); I removed it since it has a weird interaction with torch.compile.
* The open-source LLaMA repo is inference only, so there is no training support in this script.

Some other things I can improve in another PR:

* Better configuration, including sequence length and batching
* Re-enabling distributed support with fairscale

I can run the code now:

```
(bench) ubuntu@ip-172-31-39-186:~/benchmark$ python run.py llama -d cuda
Running eval method from llama on cuda in eager mode with input batch size 32.
GPU Time:             10.006 milliseconds
CPU Total Wall Time:  10.045 milliseconds
```

Pull Request resolved: #1446
Reviewed By: msaroufim
Differential Revision: D43960031
Pulled By: xuzhao9
fbshipit-source-id: 05d58ff0c92080542a16433ab3eb550322525152
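To make the fairscale removal concrete, here is a minimal sketch of the substitution described above. It is not the exact diff; the fairscale class names in the comments are taken from the upstream llama repo, and the dimensions are only illustrative.

```python
import torch.nn as nn

# Upstream llama (model-parallel via fairscale, requires torch.distributed setup):
#     from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding
#     wq = ColumnParallelLinear(dim, n_heads * head_dim, bias=False, gather_output=False)
#     tok_embeddings = ParallelEmbedding(vocab_size, dim)

# This benchmark (single process, plain PyTorch):
dim, n_heads, head_dim, vocab_size = 512, 8, 64, 32  # illustrative sizes
wq = nn.Linear(dim, n_heads * head_dim, bias=False)
tok_embeddings = nn.Embedding(vocab_size, dim)
```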
1 parent d8e5325 · commit c78f1f3
Showing 9 changed files with 416 additions and 0 deletions.
@@ -0,0 +1,39 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import torch

from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
from .model import ModelArgs, Transformer


class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.model_args = ModelArgs(vocab_size=32)
        torch.set_default_device(self.device)
        self.model = Transformer(self.model_args).to(self.device)
        self.example_inputs = (torch.tensor([[1, 1], [1, 1]], dtype=torch.int).to(self.device), 1)

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        error_msg = """
            As of March 6, 2023, the weights for this model are not publicly available
            and require a valid research reason to use. The publicly available GitHub
            repo is inference only: /~https://github.com/facebookresearch/llama
        """
        raise NotImplementedError(error_msg)

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
```
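For reference, the new benchmark entry can also be exercised directly, outside of run.py. The snippet below is a minimal sketch assuming the file above is importable as torchbenchmark.models.llama and that a CUDA device is available; it mirrors the `python run.py llama -d cuda` command from the commit summary rather than replacing it.

```python
from torchbenchmark.models.llama import Model  # assumes the package path used by run.py

m = Model(test="eval", device="cuda")      # builds the tiny random-weight Transformer above
module, example_inputs = m.get_module()    # (Transformer, (tokens, start_pos))
(out,) = m.eval()                          # forward pass under torch.no_grad()
print(out.shape)                           # logits for the last position: (batch, vocab_size + 1)
```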
@@ -0,0 +1,77 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from .tokenizer import Tokenizer
from .model import Transformer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
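As a side note (not part of the diff), sample_top_p above implements nucleus (top-p) sampling: it keeps only the highest-probability tokens whose preceding cumulative mass has not yet exceeded p, renormalizes them, and samples. A toy, made-up distribution makes the behaviour easy to see:

```python
import torch

# Assumes sample_top_p from the file above is in scope.
probs = torch.tensor([[0.60, 0.25, 0.10, 0.05]])  # made-up distribution over 4 tokens
# With p = 0.8, tokens whose preceding cumulative mass already exceeds 0.8 are
# zeroed out, so only the two most likely tokens (0.60 and 0.25) can be drawn,
# with renormalized probabilities 0.60/0.85 and 0.25/0.85.
next_token = sample_top_p(probs, p=0.8)
print(next_token)  # tensor([[0]]) or tensor([[1]])
```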
@@ -0,0 +1,8 @@
```python
import subprocess
import sys


def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])


if __name__ == '__main__':
    pip_install_requirements()
```
@@ -0,0 +1,8 @@
```yaml
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1024
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
```
@@ -0,0 +1,232 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads  # Basically we just assume world size of 1 // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # TODO: RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3
        # if mask is not None:
        #     scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(
            dim, hidden_dim, bias=False
        )
        self.w2 = nn.Linear(
            hidden_dim, dim, bias=False
        )
        self.w3 = nn.Linear(
            dim, hidden_dim, bias=False
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(
            params.vocab_size + 1, params.dim,
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim, params.vocab_size + 1, bias=False
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, seqlen = tokens.shape

        h = self.tok_embeddings(tokens)

        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None

        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
```
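A quick, illustrative shape check for the rotary-embedding helpers above (not part of the diff; the sizes are made up but consistent with the ModelArgs defaults of dim=512 and n_heads=8): precompute_freqs_cis returns a complex tensor of shape (end, dim // 2), and apply_rotary_emb leaves the query/key shapes unchanged.

```python
import torch

# Assumes precompute_freqs_cis and apply_rotary_emb from the file above are in scope.
bsz, seqlen, n_heads, head_dim = 2, 4, 8, 64          # illustrative sizes
freqs_cis = precompute_freqs_cis(head_dim, seqlen)    # complex64, shape (4, 32)
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
assert freqs_cis.shape == (seqlen, head_dim // 2)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
```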
@@ -0,0 +1 @@
origin /~https://github.com/facebookresearch/llama
@@ -0,0 +1 @@
sentencepiece