diff --git a/README.md b/README.md index 8495be9..52519cf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910 - MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](/~https://github.com/NVIDIA/waveglow). - This repository use identical mel-spectrogram function from [NVIDIA/tacotron2](/~https://github.com/NVIDIA/tacotron2), so this can be directly used to convert output from NVIDIA's tacotron2 into raw-audio. -- TODO: Planning to publish pretrained model via [PyTorch Hub](https://pytorch.org/hub). +- Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub). ![](./assets/gd.png) @@ -27,6 +27,24 @@ pip install -r requirements.txt - `python trainer.py -c [config yaml file] -n [name of the run]` - `tensorboard --logdir logs/` +## Pretrained model + +Try with Google Colab: + +```python +import torch +vocoder = torch.hub.load('seungwonpark/melgan', 'melgan') +vocoder.eval() +mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here + +if torch.cuda.is_available(): + vocoder = vocoder.cuda() + mel = mel.cuda() + +with torch.no_grad(): + audio = vocoder(mel) +``` + ## Inference - `python inference.py -p [checkpoint path] -i [input mel path]` diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..4df2f3d --- /dev/null +++ b/hubconf.py @@ -0,0 +1,40 @@ +dependencies = ['torch'] +from model.generator import Generator + +model_params = { + 'nvidia_tacotron2_LJ11_epoch3200': { + 'mel_channel': 80, + 'model_url': '', + }, +} + + +def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True): + params = model_params[model_name] + model = Generator(params['mel_channel']) + + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(params['model_url'], + progress=progress) + model.load_state_dict(state_dict['model_g']) + + model.eval(inference=True) + + return model + + +if __name__ == '__main__': + vocoder = torch.hub.load('seungwonpark/melgan', 'melgan') + mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here + + print('Input mel-spectrogram shape: {}'.format(mel.shape)) + + if torch.cuda.is_available(): + print('Moving data & model to GPU') + vocoder = vocoder.cuda() + mel = mel.cuda() + + with torch.no_grad(): + audio = vocoder.inference(mel) + + print('Output audio shape: {}'.format(audio.shape)) diff --git a/inference.py b/inference.py index 88e9b66..78d757f 100644 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ def main(args): model = Generator(hp.audio.n_mel_channels).cuda() model.load_state_dict(checkpoint['model_g']) - model.eval() + model.eval(inference=False) with torch.no_grad(): for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))): @@ -29,17 +29,7 @@ def main(args): mel = mel.unsqueeze(0) mel = mel.cuda() - # pad input mel with zeros to cut artifact - # see /~https://github.com/seungwonpark/melgan/issues/8 - zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).cuda() - mel = torch.cat((mel, zero), axis=2) - - audio = model(mel) - audio = audio.squeeze() # collapse all dimension except time axis - audio = audio[:-(hp.audio.hop_length*10)] - audio = MAX_WAV_VALUE * audio - audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE) - audio = audio.short() + audio = model.inference(hp, mel) audio = audio.cpu().detach().numpy() out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) diff --git a/model/generator.py b/model/generator.py index c6c9c4a..82d60d8 100644 --- a/model/generator.py +++ b/model/generator.py @@ -5,10 +5,13 @@ from .res_stack import ResStack #from res_stack import ResStack +MAX_WAV_VALUE = 32768.0 + class Generator(nn.Module): def __init__(self, mel_channel): super(Generator, self).__init__() + self.mel_channel = mel_channel self.generator = nn.Sequential( nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)), @@ -42,6 +45,36 @@ def forward(self, mel): mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram return self.generator(mel) + def eval(self, inference=False): + super(Generator, self).eval() + + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + for idx, layer in enumerate(self.generator): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except: + layer.remove_weight_norm() + + def inference(self, hp, mel): + # pad input mel with zeros to cut artifact + # see /~https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).to(mel.device) + mel = torch.cat((mel, zero), axis=2) + + audio = self.forward(mel) + audio = audio.squeeze() # collapse all dimension except time axis + audio = audio[:-(hp.audio.hop_length*10)] + audio = MAX_WAV_VALUE * audio + audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) + audio = audio.short() + + return audio + ''' to run this, fix diff --git a/model/res_stack.py b/model/res_stack.py index 6512409..37d9fc3 100644 --- a/model/res_stack.py +++ b/model/res_stack.py @@ -22,3 +22,8 @@ def forward(self, x): for layer in self.layers: x = x + layer(x) return x + + def remove_weight_norm(self): + for layer in self.layers: + nn.utils.remove_weight_norm(layer[1]) + nn.utils.remove_weight_norm(layer[3])