From 7b4177840c4b290efeac3c638271a143de12265c Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Mon, 12 Aug 2024 18:01:46 -0700
Subject: [PATCH] Process separate q/k/v weights in MHA converter (#1020)

Summary:
Pull Request resolved: /~https://github.com/facebookincubator/AITemplate/pull/1020

ATT. The converter did not handle the `not self._qkv_same_embed_dim` case
[here](/~https://github.com/pytorch/pytorch/blob/80ed3e9ccdaab20814b4156611a19043aaaaef03/torch/nn/modules/activation.py#L1074),
where separate q/k/v projection weights are used. This diff covers that case.

Internal:
This causes a failure in the AIT lowering of the IGCTR MC model. See the post:
https://fb.workplace.com/groups/gpuinference/permalink/2872581106223872/

Reviewed By: ColinPeppler

Differential Revision: D61155566

fbshipit-source-id: 98ba4c4150a036268ec8bcbe4f6b5aa7934374d2
---
 .../converters/ait_module_converters.py | 32 ++++++++++++++++-
 .../test_ait_multihead_attention.py     | 36 ++++++++++++++++++-
 2 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/fx2ait/fx2ait/converters/ait_module_converters.py b/fx2ait/fx2ait/converters/ait_module_converters.py
index fad6920b4..8a063804e 100644
--- a/fx2ait/fx2ait/converters/ait_module_converters.py
+++ b/fx2ait/fx2ait/converters/ait_module_converters.py
@@ -63,6 +63,7 @@ def multi_head_attention_module(
     )
 
     # Bind constant tensor for MHA module
+    q_w, k_w, v_w = None, None, None
     qkv_weight, qkv_bias = None, None
     for k, v in submod.named_parameters():
         ait_data = _TorchConstantTensorData(v.data.contiguous().cuda().half())
@@ -81,6 +82,27 @@ def multi_head_attention_module(
                     name=make_str_ait_friendly(f"{target}.{k}"),
                 )
                 qkv_bias._bind_data(ait_data)
+        elif k == "q_proj_weight":
+            q_w = Tensor(
+                shape=v.shape,
+                dtype="float16",
+                name=make_str_ait_friendly(f"{target}.{k}"),
+            )
+            q_w._bind_data(ait_data)
+        elif k == "k_proj_weight":
+            k_w = Tensor(
+                shape=v.shape,
+                dtype="float16",
+                name=make_str_ait_friendly(f"{target}.{k}"),
+            )
+            k_w._bind_data(ait_data)
+        elif k == "v_proj_weight":
+            v_w = Tensor(
+                shape=v.shape,
+                dtype="float16",
+                name=make_str_ait_friendly(f"{target}.{k}"),
+            )
+            v_w._bind_data(ait_data)
         elif "out_proj" in k:
             if "weight" in k:
                 tensor = attn.proj.weight.tensor()
@@ -90,7 +112,15 @@ def multi_head_attention_module(
             tensor._bind_data(ait_data)
 
     # Swap out qkv tensor used by nn.CrossAttention.
-    q_w, k_w, v_w = chunk()(qkv_weight, 3)
+    if qkv_weight is not None:
+        assert q_w is None
+        assert k_w is None
+        assert v_w is None
+        q_w, k_w, v_w = chunk()(qkv_weight, 3)
+    else:
+        assert q_w is not None
+        assert k_w is not None
+        assert v_w is not None
     q_b, k_b, v_b = chunk()(qkv_bias, 3)
 
     attn.proj_q.weight._tensor = q_w
diff --git a/fx2ait/fx2ait/test/converters/converters_module/test_ait_multihead_attention.py b/fx2ait/fx2ait/test/converters/converters_module/test_ait_multihead_attention.py
index 4641745ed..11579a35d 100644
--- a/fx2ait/fx2ait/test/converters/converters_module/test_ait_multihead_attention.py
+++ b/fx2ait/fx2ait/test/converters/converters_module/test_ait_multihead_attention.py
@@ -18,7 +18,7 @@
 
 
 class TestMultiHeadAttentionConverter(AITTestCase):
-    def test_multihead_attention_cross_attenytion(self):
+    def test_multihead_attention_cross_attention(self):
         class TestModule(torch.nn.Module):
             def __init__(self, dim, nheads):
                 super().__init__()
@@ -77,3 +77,37 @@ def forward(self, x):
             expected_ops={torch.nn.MultiheadAttention},
             leaf_module=torch.nn.MultiheadAttention,
         )
+
+    def test_multihead_attention_different_kv_dims(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self, qdim, kdim, vdim, nheads):
+                super().__init__()
+                self.attn = torch.nn.MultiheadAttention(
+                    embed_dim=qdim,
+                    num_heads=nheads,
+                    batch_first=True,
+                    kdim=kdim,
+                    vdim=vdim,
+                )
+
+            def forward(self, q, k, v):
+                return self.attn(query=q, key=k, value=v)
+
+        batch_size = 2
+        seqlen = 4
+        qdim = 512
+        kdim = 128
+        vdim = 128
+        num_heads = 8
+
+        q = torch.ones(batch_size, seqlen, qdim).cuda().half()
+        k = torch.ones(batch_size, seqlen, kdim).cuda().half()
+        v = torch.ones(batch_size, seqlen, vdim).cuda().half()
+        model = TestModule(qdim, kdim, vdim, num_heads).eval().half().cuda()
+
+        self.run_test(
+            model,
+            [q, k, v],
+            expected_ops={torch.nn.MultiheadAttention},
+            leaf_module=torch.nn.MultiheadAttention,
+        )
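
For reference, the parameter layout handled by the new `q_proj_weight`/`k_proj_weight`/`v_proj_weight` branches can be reproduced directly with torch.nn; a minimal sketch, assuming the same dimensions as the new test (embed_dim=512, num_heads=8, kdim=vdim=128):

import torch

# When kdim or vdim differs from embed_dim, nn.MultiheadAttention sets
# _qkv_same_embed_dim = False and registers separate q/k/v projection weights
# instead of a single fused in_proj_weight; the bias stays fused.
mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, kdim=128, vdim=128)

assert not mha._qkv_same_embed_dim
assert mha.in_proj_weight is None   # nothing to split with chunk()(qkv_weight, 3)
print(mha.q_proj_weight.shape)      # torch.Size([512, 512])
print(mha.k_proj_weight.shape)      # torch.Size([512, 128])
print(mha.v_proj_weight.shape)      # torch.Size([512, 128])
print(mha.in_proj_bias.shape)       # torch.Size([1536]); still split via chunk()(qkv_bias, 3)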