Skip to content

Commit

Permalink
Fix SegformerForImageClassification (#15895)
Browse files Browse the repository at this point in the history
* Fix reshape

* Apply suggestion from code review

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
  • Loading branch information
NielsRogge and Niels Rogge authored Mar 2, 2022
1 parent 130b987 commit 89be34c
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/transformers/models/segformer/modeling_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,11 @@ def forward(

sequence_output = outputs[0]

# reshape last hidden states to (batch_size, height*width, hidden_size)
# convert last hidden states to (batch_size, height*width, hidden_size)
batch_size = sequence_output.shape[0]
if self.config.reshape_last_stage:
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
sequence_output = sequence_output.permute(0, 2, 3, 1)
sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])

# global average pooling
Expand Down Expand Up @@ -660,10 +663,19 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

self.config = config

def forward(self, encoder_hidden_states):
batch_size, _, _, _ = encoder_hidden_states[-1].shape
batch_size = encoder_hidden_states[-1].shape[0]

all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
encoder_hidden_state = (
encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
)

# unify channel dimension
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state)
Expand Down

0 comments on commit 89be34c

Please sign in to comment.