-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathvggish_slim.py
163 lines (136 loc) · 7.15 KB
/
vggish_slim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the 'VGGish' model used to generate AudioSet embedding features.
The public AudioSet release (https://research.google.com/audioset/download.html)
includes 128-D features extracted from the embedding layer of a VGG-like model
that was trained on a large Google-internal YouTube dataset. Here we provide
a TF-Slim definition of the same model, without any dependencies on libraries
internal to Google. We call it 'VGGish'.
Note that we only define the model up to the embedding layer, which is the
penultimate layer before the final classifier layer. We also provide various
hyperparameter values (in vggish_params.py) that were used to train this model
internally.
For comparison, here is TF-Slim's VGG definition:
https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
"""
import tensorflow as tf
import vggish_params as params
slim = tf.contrib.slim
def define_vggish_slim(training=False):
    """Defines the VGGish TensorFlow model.

    All ops are created in the current default graph, under the scope
    'vggish/'. The input is a placeholder named 'vggish/input_features' of
    type float32 and shape [batch_size, num_frames, num_bands], where
    batch_size is variable and num_frames and num_bands are the constants
    params.NUM_FRAMES and params.NUM_BANDS. Each [num_frames, num_bands]
    slice represents a log-mel-scale spectrogram patch covering num_bands
    frequency bands and num_frames time frames (where each frame step is
    usually 10ms). This is produced by computing the stabilized
    log(mel-spectrogram + params.LOG_OFFSET).

    The output is an op named 'vggish/embedding' which produces the
    activations of a params.EMBEDDING_SIZE-D (128-D) embedding layer, which
    is usually the penultimate layer when used as part of a full model with a
    final classifier layer.

    Args:
        training: If true, all parameters are marked trainable.

    Returns:
        The output tensor of the op 'vggish/embedding'.
    """
    # Layer defaults applied through the arg_scopes below:
    # - All weights are initialized from a truncated N(0, INIT_STDDEV).
    # - All biases are initialized to 0.
    # - All activations are ReLU (including the final embedding layer).
    # - All convolutions are 3x3 with stride 1 and SAME padding.
    # - All max-pools are 2x2 with stride 2 and SAME padding.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        weights_initializer=tf.truncated_normal_initializer(
                            stddev=params.INIT_STDDEV),
                        biases_initializer=tf.zeros_initializer(),
                        activation_fn=tf.nn.relu,
                        trainable=training), \
         slim.arg_scope([slim.conv2d],
                        kernel_size=[3, 3], stride=1, padding='SAME'), \
         slim.arg_scope([slim.max_pool2d],
                        kernel_size=[2, 2], stride=2, padding='SAME'), \
         tf.variable_scope('vggish'):
        # Input: a batch of 2-D log-mel-spectrogram patches.
        features = tf.placeholder(
            tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
            name='input_features')
        # Reshape to 4-D (one input channel) so conv2d() can convolve a batch.
        net = tf.reshape(features,
                         [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])

        # The VGG stack of alternating convolutions and max-pools.
        net = slim.conv2d(net, 64, scope='conv1')
        net = slim.max_pool2d(net, scope='pool1')
        net = slim.conv2d(net, 128, scope='conv2')
        net = slim.max_pool2d(net, scope='pool2')
        net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
        net = slim.max_pool2d(net, scope='pool3')
        net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
        net = slim.max_pool2d(net, scope='pool4')

        # Flatten before entering the fully-connected layers.
        net = slim.flatten(net)
        net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
        # The embedding layer. Note: the first arg_scope above also applies
        # its ReLU activation here, so the embedding values are non-negative.
        net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
        return tf.identity(net, name='embedding')
def load_vggish_slim_checkpoint(session, checkpoint_path):
    """Loads a pre-trained VGGish-compatible checkpoint.

    This function can be used as an initialization function (referred to as
    init_fn in TensorFlow documentation) which is called in a Session after
    initializing all variables. When used as an init_fn, this will load
    a pre-trained checkpoint that is compatible with the VGGish model
    definition. Only variables defined by VGGish will be loaded.

    The variables to restore are collected from the *current default graph*
    (not from ``session.graph``). See load_defined_vggish_slim_checkpoint()
    for a variant that looks them up in the session's own graph.

    Args:
        session: an active TensorFlow session.
        checkpoint_path: path to a file containing a checkpoint that is
            compatible with the VGGish model definition.
    """
    # Build a throwaway inference-mode VGGish graph only to learn the names of
    # all VGGish variables that exist in the checkpoint. A set gives O(1)
    # membership tests in the selection below.
    with tf.Graph().as_default():
        define_vggish_slim(training=False)
        vggish_var_names = {v.name for v in tf.global_variables()}

    # Get the list of all currently existing variables (in the default graph)
    # whose names match the set we just computed; order follows the graph.
    vggish_vars = [v for v in tf.global_variables()
                   if v.name in vggish_var_names]

    # Use a Saver to restore just the variables selected above.
    saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained',
                           write_version=1)
    saver.restore(session, checkpoint_path)
def load_defined_vggish_slim_checkpoint(session, checkpoint_path):
    """Loads a pre-trained VGGish-compatible checkpoint into session's graph.

    This function can be used as an initialization function (referred to as
    init_fn in TensorFlow documentation) which is called in a Session after
    initializing all variables. When used as an init_fn, this will load
    a pre-trained checkpoint that is compatible with the VGGish model
    definition. Only variables defined by VGGish will be loaded.

    Unlike load_vggish_slim_checkpoint(), which reads variables from the
    current default graph, this looks the variables up in ``session.graph``
    and builds the restoring Saver in that same graph, so ``session.graph``
    does not have to be the default graph when this is called.

    Args:
        session: an active TensorFlow session whose graph already holds the
            VGGish variables to restore.
        checkpoint_path: path to a file containing a checkpoint that is
            compatible with the VGGish model definition.
    """
    # Build a throwaway inference-mode VGGish graph only to learn the names of
    # all VGGish variables that exist in the checkpoint. A set gives O(1)
    # membership tests in the selection below.
    with tf.Graph().as_default():
        define_vggish_slim(training=False)
        vggish_var_names = {v.name for v in tf.global_variables()}

    # Work inside the graph owned by the session: select its variables whose
    # names match, and create the Saver there so its restore ops are added to
    # the same graph that `session` runs.
    with session.graph.as_default():
        vggish_vars = [v for v in tf.global_variables()
                       if v.name in vggish_var_names]
        saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained',
                               write_version=1)
    saver.restore(session, checkpoint_path)