-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathvggish_train_demo.py
195 lines (161 loc) · 7.82 KB
/
vggish_train_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in training mode.
This is intended as a toy example that demonstrates how to use the VGGish model
definition within a larger model that adds more layers on top, and then train
the larger model. If you let VGGish train as well, then this allows you to
fine-tune the VGGish model parameters for your application. If you don't let
VGGish train, then you use VGGish as a feature extractor for the layers above
it.
For this toy task, we are training a classifier to distinguish between three
classes: sine waves, constant signals, and white noise. We generate synthetic
waveforms from each of these classes, convert into shuffled batches of log mel
spectrogram examples with associated labels, and feed the batches into a model
that includes VGGish at the bottom and a couple of additional layers on top. We
also plumb in labels that are associated with the examples, which feed a label
loss used for training.
Usage:
# Run training for 100 steps using a model checkpoint in the default
# location (vggish_model.ckpt in the current directory). Allow VGGish
# to get fine-tuned.
$ python vggish_train_demo.py --num_batches 100
# Same as before but run for fewer steps and don't change VGGish parameters
# and use a checkpoint in a different location
$ python vggish_train_demo.py --num_batches 50 \
--train_vggish=False \
--checkpoint /path/to/model/checkpoint
"""
from __future__ import print_function
from random import shuffle
import numpy as np
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_slim
# Short aliases for the TF 1.x flag-parsing and slim layer libraries.
flags = tf.app.flags
slim = tf.contrib.slim

# Command-line flags controlling the demo training run.
flags.DEFINE_integer(
    'num_batches', 30,
    'Number of batches of examples to feed into the model. Each batch is of '
    'variable size and contains shuffled examples of each class of audio.')

flags.DEFINE_boolean(
    'train_vggish', True,
    'If True, allow VGGish parameters to change during training, thus '
    'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using '
    'VGGish as a fixed feature extractor.')

flags.DEFINE_string(
    'checkpoint', 'vggish_model.ckpt',
    'Path to the VGGish checkpoint file.')

FLAGS = flags.FLAGS

# Number of classes in the toy task; also the width of each label vector and
# of the classifier's logits layer.
_NUM_CLASSES = 3
def _get_examples_batch():
  """Returns a shuffled batch of examples of all audio classes.

  Note that this is just a toy function because this is a simple demo intended
  to illustrate how the training code might work.

  Three synthetic 5-second waveforms sampled at 44.1 kHz are generated (a sine
  wave of random frequency, a constant signal of random level, and zero-mean
  Gaussian white noise). Each waveform is converted into examples with
  vggish_input.waveform_to_examples, and the examples of all three classes are
  pooled and shuffled together with their labels.

  Returns:
    a tuple (features, labels) where features is a list of batch_size log mel
    spectrogram patches, each a NumPy array of shape [num_frames, num_bands]
    suitable for feeding VGGish, while labels is a list of batch_size NumPy
    arrays of shape [num_classes], each a multi-hot label vector for the
    patch at the same position in features. batch_size is variable.
  """
  # Make a waveform for each class.
  num_seconds = 5
  sr = 44100  # Sampling rate.
  t = np.linspace(0, num_seconds, int(num_seconds * sr))  # Time axis.

  # Random sine wave.
  freq = np.random.uniform(100, 1000)
  sine = np.sin(2 * np.pi * freq * t)

  # Random constant signal: one random level held for the whole duration.
  # (Multiplying the level by the time axis would produce a ramp instead.)
  magnitude = np.random.uniform(-1, 1)
  const = np.full_like(t, magnitude)

  # White noise: zero-mean, unit-variance Gaussian samples. A non-zero mean
  # would add a DC offset on top of the noise.
  noise = np.random.normal(0, 1, size=t.shape)

  # Make examples of each signal and corresponding labels.
  # Sine is class index 0, Const class index 1, Noise class index 2.
  sine_examples = vggish_input.waveform_to_examples(sine, sr)
  sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0])
  const_examples = vggish_input.waveform_to_examples(const, sr)
  const_labels = np.array([[0, 1, 0]] * const_examples.shape[0])
  noise_examples = vggish_input.waveform_to_examples(noise, sr)
  noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0])

  # Shuffle (example, label) pairs across all classes.
  all_examples = np.concatenate((sine_examples, const_examples, noise_examples))
  print('all_examples shape:', all_examples.shape)
  all_labels = np.concatenate((sine_labels, const_labels, noise_labels))
  print('all_labels shape:', all_labels.shape)
  # Zipping first keeps every example paired with its own label while the
  # order is randomized in place.
  labeled_examples = list(zip(all_examples, all_labels))
  shuffle(labeled_examples)

  # Separate and return the features and labels.
  features = [example for (example, _) in labeled_examples]
  labels = [label for (_, label) in labeled_examples]
  return (features, labels)
def main(_):
  """Builds a small classifier on top of VGGish and trains it on toy batches.

  The graph consists of VGGish (trainable or frozen according to
  FLAGS.train_vggish), a 100-unit fully connected layer, and a logits layer
  with one independent sigmoid classifier per class. After the variables are
  initialized and the VGGish checkpoint at FLAGS.checkpoint is loaded,
  FLAGS.num_batches training steps are run, printing the step number and loss
  after each one.

  Args:
    _: Unused; accepted because the entry point is invoked via tf.app.run().
  """
  with tf.Graph().as_default(), tf.Session() as sess:
    # VGGish goes at the bottom of the model and produces the embeddings.
    vggish_embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)

    # A shallow classification head, plus its training ops, stacked on VGGish.
    with tf.variable_scope('mymodel'):
      # Hidden fully connected layer.
      hidden_units = 100
      hidden = slim.fully_connected(vggish_embeddings, hidden_units)

      # Output layer: parallel logistic classifiers, one per class, so more
      # than one class may be active for a given example.
      class_logits = slim.fully_connected(
          hidden, _NUM_CLASSES, activation_fn=None, scope='logits')
      tf.sigmoid(class_logits, name='prediction')

      # Training ops.
      with tf.variable_scope('train'):
        step_collections = [tf.GraphKeys.GLOBAL_VARIABLES,
                            tf.GraphKeys.GLOBAL_STEP]
        step_variable = tf.Variable(
            0, name='global_step', trainable=False,
            collections=step_collections)

        # Labels arrive as a batch of multi-hot vectors: 1 at the position of
        # every positive class, 0 everywhere else.
        labels_placeholder = tf.placeholder(
            tf.float32, shape=(None, _NUM_CLASSES), name='labels')

        # Sigmoid cross-entropy per label, averaged into a scalar loss.
        per_label_xent = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=class_logits, labels=labels_placeholder, name='xent')
        mean_loss = tf.reduce_mean(per_label_xent, name='loss_op')
        tf.summary.scalar('loss', mean_loss)

        # Same optimizer and hyperparameters as were used to train VGGish.
        adam = tf.train.AdamOptimizer(
            learning_rate=vggish_params.LEARNING_RATE,
            epsilon=vggish_params.ADAM_EPSILON)
        adam.minimize(mean_loss, global_step=step_variable, name='train_op')

    # Initialize every variable first, then overwrite the VGGish ones with the
    # pre-trained checkpoint values.
    sess.run(tf.global_variables_initializer())
    vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)

    # Look up, by graph name, the tensors and ops the training loop needs.
    graph = sess.graph
    features_tensor = graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
    labels_tensor = graph.get_tensor_by_name('mymodel/train/labels:0')
    global_step_tensor = graph.get_tensor_by_name(
        'mymodel/train/global_step:0')
    loss_tensor = graph.get_tensor_by_name('mymodel/train/loss_op:0')
    train_op = graph.get_operation_by_name('mymodel/train/train_op')
    fetches = [global_step_tensor, loss_tensor, train_op]

    # The training loop: one freshly generated batch per step.
    for _batch_index in range(FLAGS.num_batches):
      batch_features, batch_labels = _get_examples_batch()
      step_value, loss_value, _train_result = sess.run(
          fetches,
          feed_dict={
              features_tensor: batch_features,
              labels_tensor: batch_labels,
          })
      print('Step %d: loss %g' % (step_value, loss_value))
if __name__ == '__main__':
  # Parse the command-line flags, then hand control to main().
  tf.app.run(main=main)