
Add MaxViT model #912

Closed

innat opened this issue Oct 12, 2022 · 29 comments

Comments
@innat
Contributor

innat commented Oct 12, 2022

Short Description

[Image: benchmark plot comparing MaxViT's accuracy/efficiency against SoTA models]

Multi-Axis Vision Transformer: MaxViT is a family of hybrid (CNN + ViT) image classification models that achieves better performance across the board, in both parameter and FLOPs efficiency, than SoTA ConvNets and Transformers.

Papers

https://arxiv.org/abs/2204.01697

Existing Implementations

Official Implementation:
Google, TensorFlow 2 (Keras): /~https://github.com/google-research/maxvit

cc. @Yinxiaoli @vztu
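
For context: per the paper (Fig. 2), each MaxViT block applies an MBConv, then block (window) self-attention, then grid self-attention. A hypothetical pseudocode sketch of that composition (helper names are placeholders, not actual APIs):

def maxvit_block(x):
    x = mb_conv(x)          # local feature extraction (MBConv with squeeze-excitation)
    x = block_attention(x)  # self-attention within non-overlapping local windows
    x = grid_attention(x)   # self-attention across a sparse, dilated global grid
    return x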

@bhack
Contributor

bhack commented Oct 12, 2022

Quite related to #911

@ayulockin

Hey all, I can work on this. :)

@DavidLandup0
Contributor

I'd gladly also port this from the official repo to here. :)
If someone could assign me to it, I'd get to it as soon as the Dice and Jaccard coefficients are done.

@tanzhenyu
Contributor

> Hey all, I can work on this. :)

@ayulockin @innat Ideally we would like to have SwinTransformer first:
#671

@innat
Contributor Author

innat commented Oct 13, 2022

@tanzhenyu
I think vanilla ViT should come first; it's like the VGG of transformers 😄

In the meantime, I think it's also fine to start working on the basic components, like window partitioning, grid attention, TrailDense, etc. cc @ayulockin @DavidLandup0
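
A rough sketch of the window and grid partition ops, following the reshape/transpose pattern used in the official repo (illustrative only; assumes NHWC inputs with height/width divisible by the window/grid size):

import tensorflow as tf

def window_partition(x, window_size):
    # [B, H, W, C] -> [B * num_windows, ws, ws, C]; groups are local windows.
    _, h, w, c = x.shape
    x = tf.reshape(x, (-1, h // window_size, window_size, w // window_size, window_size, c))
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    return tf.reshape(x, (-1, window_size, window_size, c))

def grid_partition(x, grid_size):
    # [B, H, W, C] -> [B * num_groups, gs, gs, C]; each group is a dilated
    # global grid, which is what gives grid attention its sparse global view.
    _, h, w, c = x.shape
    x = tf.reshape(x, (-1, grid_size, h // grid_size, grid_size, w // grid_size, c))
    x = tf.transpose(x, (0, 2, 4, 1, 3, 5))
    return tf.reshape(x, (-1, grid_size, grid_size, c))

The un-partition ops would be the exact inverse reshapes/transposes.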

@tanzhenyu
Contributor

> @tanzhenyu I think vanilla ViT should come first; it's like the VGG of transformers 😄
>
> In the meantime, I think it's also fine to start working on the basic components, like window partitioning, grid attention, TrailDense, etc. cc @ayulockin @DavidLandup0

#668

@DavidLandup0
Contributor

Creating a pull request later today with layers for patching, MLP heads, linear projections, etc. We can use those to build a ViT and then extend it to Swin and other transformers for vision. A rough draft of ViT will come in with the basic layers. Would you prefer a PR for the components, and then a PR for ViT on a different branch instead? @tanzhenyu @innat

@bhack
Contributor

bhack commented Oct 13, 2022

Not that we need to do the same, but at the same time I'll take a look at how modularization is organized in the quite popular Hugging Face Transformers API:
/~https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py

@tanzhenyu
Contributor

> Creating a pull request later today with layers for patching, MLP heads, linear projections, etc. We can use those to build a ViT and then extend it to Swin and other transformers for vision. A rough draft of ViT will come in with the basic layers. Would you prefer a PR for the components, and then a PR for ViT on a different branch instead? @tanzhenyu @innat

Given we don't anticipate the need to expose components such as linear projections as public APIs, either a single PR or multiple PRs sounds good to me.
In this case, it really depends on whether you want them to be reused in Swin and others. Modularity is the key. If that's what you want, would you mind coming up with a basic design to show how those components can fit into other models?

@DavidLandup0
Contributor

DavidLandup0 commented Oct 13, 2022

Sure! I'm packaging them into one PR as a draft overview, just to check whether the general structure is okay; it'd be unwise to keep building on it if major changes are needed. I'm testing out a rough idea and will push it later for a cursory look/review :)

The idea was to build blocks that we can reuse for most transformer-based models. Currently, building a ViT with it looks like this:

inputs = utils.parse_model_inputs(input_shape, input_tensor)
x = inputs

if include_rescaling:
    x = layers.Rescaling(1 / 255.0)(x)

# Patch the image, then linearly project and position-encode the patches.
patches = keras_cv.layers.Patching(patch_size)(x)
x = keras_cv.transformers.PatchEncoder(num_patches, project_dim)(patches)

# Stack of encoder blocks; each block consumes the previous block's output.
for _ in range(transformer_layer_num):
    x = keras_cv.transformers.TransformerEncoder()(x)

representation = layers.LayerNormalization(epsilon=1e-6)(x)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)

features = mlp_ffn(representation, hidden_units=head_units, dropout_rate=0.5)
logits = layers.Dense(num_classes)(features)
model = keras.Model(inputs=inputs, outputs=logits)

@ayulockin

> In the meantime, I think it's also fine to start working on the basic components, like window partitioning, grid attention, TrailDense, etc. cc @ayulockin @DavidLandup0

I think adding the basic components you mentioned should be the way to go. KerasCV's aim is to provide components for industrial adoption of research. Instead of focusing on models (ViT, Swin, etc.), we should scope the transformers for vision such that we can build the fundamental blocks.

@bhack
Contributor

bhack commented Oct 13, 2022

> I think adding the basic components you mentioned should be the way to go. KerasCV's aim is to provide components for industrial adoption of research. Instead of focusing on models (ViT, Swin, etc.), we should scope the transformers for vision such that we can build the fundamental blocks.

That's why I've suggested exploring the Hugging Face Transformers modules.

It's probably not the best modularization we could achieve, but they have already accumulated quite a relevant list of transformer architectures in the library.

I don't know whether it's production-level, but it's at least partially validated by the number of models.

@tanzhenyu
Contributor

> Sure! I'm packaging them into one PR as a draft overview, just to check whether the general structure is okay; it'd be unwise to keep building on it if major changes are needed. I'm testing out a rough idea and will push it in ~30min for a cursory look/review :)
>
> The idea was to build blocks that we can reuse for most transformer-based models. Currently, building a ViT with it looks like this:
>
>     inputs = utils.parse_model_inputs(input_shape, input_tensor)
>     x = inputs
>
>     if include_rescaling:
>         x = layers.Rescaling(1 / 255.0)(x)
>
>     patches = keras_cv.layers.Patching(patch_size)(x)
>     x = keras_cv.transformers.PatchEncoder(num_patches, project_dim)(patches)
>
>     for _ in range(transformer_layer_num):
>         x = keras_cv.transformers.TransformerEncoder()(x)
>
>     representation = layers.LayerNormalization(epsilon=1e-6)(x)
>     representation = layers.Flatten()(representation)
>     representation = layers.Dropout(0.5)(representation)
>
>     features = mlp_ffn(representation, hidden_units=head_units, dropout_rate=0.5)
>     logits = layers.Dense(num_classes)(features)
>     model = keras.Model(inputs=inputs, outputs=logits)

This seems concise. My only question is whether this TransformerEncoder is implemented differently from "normal" transformer encoders, or the same way. If it's the same, we might want to consider bringing it to core Keras instead of KerasCV.

And yes, I agree with @bhack and @ayulockin that we should take a look at HF's implementation and make sure we're providing enough modularization.

@bhack
Contributor

bhack commented Oct 13, 2022

> If it's the same, we might want to consider bringing it to core Keras instead of KerasCV.

Yes, this is another very important point, already discussed, to minimize (future?) duplication with KerasNLP.

@tanzhenyu tanzhenyu added the wip working in progress from KerasCV team label Oct 27, 2022
@tanzhenyu tanzhenyu added stat:contributions welcome and removed wip working in progress from KerasCV team labels Dec 15, 2022
@DavidLandup0
Contributor

As the ViTs are finished, I'll be working on this one now ;)
If anyone wants to collab, let me know. (@ayulockin wanted to work on this a while back)

@ayulockin

Hey, @DavidLandup0, I would love to collaborate on this with you. :) I was waiting for the ViT to be added so I could build on top of it from a design point of view. Since you have worked on it, collaborating with you would be a great learning experience. :)

@innat
Contributor Author

innat commented Dec 15, 2022

@ayulockin just to note, MAXIM is welcome too. Most of the official code (JAX) + weights were ported to Keras here.

@DavidLandup0
Contributor

If nobody else signs up for it by the time MaxViT is done, I'd gladly hop onto MAXIM too :)

@DavidLandup0
Contributor

Since MaxViT uses MBConvs, which we have in the EfficientNets, and which originated in the MobileNets, we'll have three architectures reusing the same blocks. Additionally, having them as a layer would let users build their own networks with them for edge/mobile applications.

I think we should have MBConv as a standalone layer.

Can I separate it into a layer and refactor EfficientNets in preparation for MaxViT? @tanzhenyu @LukeWood @bhack
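
For discussion, a minimal functional-style sketch of the inverted-bottleneck (MBConv) pattern with squeeze-and-excitation, in the spirit of the MobileNetV2/EfficientNet blocks (hyperparameters and the helper name are illustrative, not the eventual KerasCV API):

import tensorflow as tf
from tensorflow.keras import layers

def mb_conv(x, filters, expand_ratio=4, se_ratio=0.25, strides=1):
    # Inverted bottleneck: expand (1x1) -> depthwise (3x3) -> SE -> project (1x1).
    input_filters = x.shape[-1]
    residual = x
    x = layers.Conv2D(input_filters * expand_ratio, 1, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("gelu")(x)
    x = layers.DepthwiseConv2D(3, strides=strides, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("gelu")(x)
    # Squeeze-and-excitation: channel attention from globally pooled features.
    se = layers.GlobalAveragePooling2D(keepdims=True)(x)
    se = layers.Conv2D(int(input_filters * se_ratio), 1, activation="gelu")(se)
    se = layers.Conv2D(input_filters * expand_ratio, 1, activation="sigmoid")(se)
    x = x * se
    x = layers.Conv2D(filters, 1, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    # Residual connection only when the shapes line up.
    if strides == 1 and input_filters == filters:
        x = x + residual
    return x

Factoring something like this out as a layer would let EfficientNet, MobileNet, and MaxViT all share one code path.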

@IMvision12
Contributor

IMvision12 commented Dec 15, 2022

I can work on MAXIM !!

@tanzhenyu
Contributor

> Since MaxViT uses MBConvs, which we have in the EfficientNets, and which originated in the MobileNets, we'll have three architectures reusing the same blocks. Additionally, having them as a layer would let users build their own networks with them for edge/mobile applications.
>
> I think we should have MBConv as a standalone layer.
>
> Can I separate it into a layer and refactor EfficientNets in preparation for MaxViT? @tanzhenyu @LukeWood @bhack

Yep, it'd be great to reuse both MBConv and SE.

@DavidLandup0
Contributor

Done in new PR :)
#1146

@tanzhenyu
Contributor

> I can work on MAXIM !!

Go ahead!

@ayulockin

Here is a quick update on the work done so far:

Work done in collaboration with @DavidLandup0 :)

We have almost all the components done: WindowPartition, UnWindowPartition, GridPartition, UnGridPartition, and RelativeMultiHeadAttention.

We have stacked them together to build a barebones MaxViTBlock, and the input and output signatures match the official implementation. We will package it in a class and create the MaxViT variants. Will send over a PR once done. :)

@DavidLandup0, do you have anything more to add?

cc: @innat @bhack @tanzhenyu
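
For readers unfamiliar with the relative variant: the key difference from standard multi-head attention is a learned relative position bias added to the attention logits before the softmax, as in Swin and the official MaxViT code. A self-contained sketch of just that step (shapes and names are illustrative):

import tensorflow as tf

def relative_attention_weights(q, k, relative_bias):
    # q, k: [batch, heads, seq, head_dim]; relative_bias: [heads, seq, seq],
    # gathered from a learned table indexed by relative (row, col) offsets.
    head_dim = tf.cast(tf.shape(q)[-1], q.dtype)
    logits = tf.matmul(q, k, transpose_b=True) / tf.sqrt(head_dim)
    return tf.nn.softmax(logits + relative_bias, axis=-1)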

@DavidLandup0
Contributor

Thanks for tagging, and awesome work on RelativeMultiHeadAttention! A question for the Keras team: do we want to make RelativeMultiHeadAttention part of core Keras? MHA already is, and the relative variant is general enough for it, IMO.

Since we should package the components for review first, it's enough to have a rough model in the first PR to prove that they work and to assess their usage. I'll do the MaxViTTransformerEncoder, and then we can open the components PR.

It'd be a good idea to see if we can generalize the existing transformer encoder to be shared between ViTs and MaxViTs, since they're not too different (and allow the type of multi-head attention to be changed). The main counterargument is that it already has quite a few arguments, so a general encoder with even more might not be very user-friendly.

Thoughts?

@bhack
Contributor

bhack commented Dec 29, 2022

> The main counterargument is that it already has quite a few arguments, so a general encoder with even more might not be very user-friendly.

Generally, this could be an indirect signal that it requires a base class.
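
To make the base-class idea concrete, one possible shape for it (a minimal sketch; class and method names are hypothetical, and the subclass assumes the in-progress RelativeMultiHeadAttention layer): subclasses override an attention factory instead of the encoder growing its argument list.

import tensorflow as tf

class BaseTransformerEncoder(tf.keras.layers.Layer):
    # Shared skeleton: pre-norm attention + MLP; subclasses choose the attention.
    def __init__(self, project_dim, num_heads, mlp_dim, **kwargs):
        super().__init__(**kwargs)
        self.attn = self.build_attention(num_heads, project_dim)
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_dim, activation="gelu"),
            tf.keras.layers.Dense(project_dim),
        ])

    def build_attention(self, num_heads, key_dim):
        return tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y)
        return x + self.mlp(self.norm2(x))

class MaxViTTransformerEncoder(BaseTransformerEncoder):
    def build_attention(self, num_heads, key_dim):
        # Swap in the relative variant discussed above (hypothetical signature).
        return RelativeMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)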

@DavidLandup0
Contributor

DavidLandup0 commented Dec 29, 2022

For reference, this is the constructor:

def __init__(
    self,
    project_dim,
    num_heads,
    mlp_dim,
    mlp_dropout=0.1,
    attention_dropout=0.1,
    activation=tf.keras.activations.gelu,
    layer_norm_epsilon=1e-06,
    attention_type='mha',
    **kwargs,
):

Though, because of the defaults, usage can be as simple as:

keras_cv.layers.TransformerEncoder(
    project_dim=project_dim,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
)(encoded_patches)

Now, I remember KerasNLP having this same issue. We might not be able to have a fully general TransformerEncoder for all cases, so it might be better to do them separately?

In the case of MaxViT, it's one extra arg that simply selects the attention layer:

if attention_type == 'mha':
    attention_layer = layers.MultiHeadAttention
elif attention_type == 'relmha':
    attention_layer = layers.RelativeMultiHeadAttention

So it's a small change. The question is mainly about work down the line, when we might need to support more options.

@ayulockin

I am in favour of a separate TransformerEncoder. It allows for speedy implementation, since vision transformers evolve rapidly.

The counterargument is that we first implement a handful of vision transformers and then try to build a unified transformer encoder by introducing a base class.

@tanzhenyu
Contributor

> Here is a quick update on the work done so far:
>
> Work done in collaboration with @DavidLandup0 :)
>
> We have almost all the components done: WindowPartition, UnWindowPartition, GridPartition, UnGridPartition, and RelativeMultiHeadAttention.
>
> We have stacked them together to build a barebones MaxViTBlock, and the input and output signatures match the official implementation. We will package it in a class and create the MaxViT variants. Will send over a PR once done. :)
>
> @DavidLandup0, do you have anything more to add?
>
> cc: @innat @bhack @tanzhenyu

Great progress! The breakdown of those components sounds good to me. @vztu @Yinxiaoli, can you comment here?

Re David's question: I think it'd be nice to have a transformer encoder that accepts different attention mechanisms, though we don't plan to move relative attention to core Keras yet (maybe later, given there are so many different attention variants out there). If MaxViT can reuse the encoder, that'd be great; the core value of KerasCV is always to provide generic components.

@ayulockin ayulockin mentioned this issue Jan 24, 2023
@innat innat closed this as completed Aug 4, 2023