-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_acquisition.py
189 lines (143 loc) · 5.77 KB
/
data_acquisition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
This script downloads tracks with valid metadata in batches from captain-hammer.
Raw wav files created are then used with feature_extraction.py
"""
import logging
import sys
from pathlib import Path
from tempfile import TemporaryFile
import requests
from pydub import AudioSegment
# logging configuration: this module's logger sends every record
# (DEBUG and above) to standard output.
logger = logging.getLogger(__name__)
stdout_log_handler = logging.StreamHandler(sys.stdout)
stdout_log_handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
logger.addHandler(stdout_log_handler)
# Item endpoint of the beets server on captain-hammer; a track's metadata
# lives at BEETS_API_ROOT + <id>.
BEETS_API_ROOT = 'http://captain-hammer.local:8337/item/'
# Where downloaded WAVs are written, resolved against the working directory
# at import time.
TEST_WAVS_DIRECTORY = Path.cwd() / 'data' / 'test' / 'wavs'
# bpm acceptable range (both ends inclusive)
LOWER_BOUND, UPPER_BOUND = 120, 130
def genre_is_of_interest(genre, keywords=('techno', 'house', 'dance', 'tech')):
    """Return whether a genre tag matches one of the targeted genres.

    Matching is a case-insensitive substring test, so for example
    "Tech House" and "deep house" both match. The function is a pure
    predicate: callers that want the genre logged already do so themselves.

    Args:
        genre (str or None): genre tag of the track. ``None`` or an empty
            string never matches.
        keywords (iterable of str): substrings that mark a genre as
            interesting. The default is the original hard-coded list.

    Returns:
        bool: True if any keyword occurs in ``genre``.
    """
    if not genre:
        # Missing genre: previously `keyword in None` raised TypeError.
        return False
    folded = genre.casefold()
    return any(keyword.casefold() in folded for keyword in keywords)
def bpm_is_in_range(bpm, lower_bound, upper_bound):
    """Tell whether ``bpm`` lies in the closed interval [lower_bound, upper_bound].

    Both bounds are inclusive.

    Returns:
        bool: True when lower_bound <= bpm <= upper_bound.
    """
    inside = lower_bound <= bpm <= upper_bound
    return bool(inside)
def get_beets_track_bpm_and_format_tags(beets_track_url, timeout=30):
    """Download a track's metadata from the server and extract its tags.

    Args:
        beets_track_url (str): URL of the track on the server, e.g.
            ``BEETS_API_ROOT + str(beets_id)``.
        timeout (float): seconds to wait for the server before giving up.

    Returns:
        tuple: ``(bpm, file_format, genre)`` where
            bpm (int or None): the bpm tag truncated to an int; None when
                the tag is missing, null or zero.
            file_format (str or None): lowercased file format (e.g. 'mp3');
                None when missing.
            genre (str): lowercased genre; '' when missing, so callers can
                always unpack the tuple and the track simply fails genre
                filtering.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        ValueError: if the response body is not JSON.
    """
    logger.info('Getting beets track <%s> metadata…', beets_track_url)
    response = requests.get(beets_track_url, timeout=timeout)
    response.raise_for_status()
    metadata = response.json()
    # `or` also covers keys that are present but null, which previously
    # crashed int()/.lower().
    bpm = int(metadata.get('bpm') or 0) or None
    genre = (metadata.get('genre') or '').lower()
    logger.info('genre: %s', genre)
    file_format = (metadata.get('format') or '').lower() or None
    return bpm, file_format, genre
def download_beets_track_file(beets_track_url, timeout=30, chunk_size=1 << 16):
    """Download a track's audio into an anonymous temporary file.

    Args:
        beets_track_url (str): URL of the track on the server; the audio is
            fetched from ``<beets_track_url>/file``.
        timeout (float): seconds to wait for the server (connect, and
            between received bytes).
        chunk_size (int): number of bytes written per chunk while streaming.

    Returns:
        TemporaryFile: binary handle positioned at the start of the audio.
        Closing it deletes the file.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            an error page is never saved as if it were audio.
        requests.RequestException: on connection problems or timeouts.
    """
    logger.info('Downloading beets track <%s> audio file…', beets_track_url)
    f = TemporaryFile()
    try:
        # Stream to disk instead of holding the whole file in memory.
        with requests.get(beets_track_url + '/file', stream=True,
                          timeout=timeout) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    except Exception:
        f.close()  # TemporaryFile gets deleted on close
        raise
    f.seek(0)  # in case we want to read this later
    return f
def convert_mp3_to_wav_file(mp3_file):
    """Decode an MP3 and re-encode it as single-channel WAV.

    Args:
        mp3_file: MP3 file (handle or path) accepted by
            ``AudioSegment.from_mp3``.

    Returns:
        TemporaryFile: binary handle, rewound to the start, containing the
        mono WAV data. Closing it deletes the file.
    """
    logger.debug('Converting %s to WAV…', mp3_file)
    mono_audio = AudioSegment.from_mp3(mp3_file).set_channels(1)
    out_handle = TemporaryFile()
    mono_audio.export(out_handle, format="wav")
    out_handle.seek(0)
    return out_handle
def download_all_beets_tracks():
    """Download every beets track with a non-zero BPM into TEST_WAVS_DIRECTORY.

    MP3 tracks are converted to WAV, WAV tracks are saved unchanged, and any
    other format is skipped with a warning (its bytes are not WAV data).
    Each file is written as ``<id>.wav``.

    Raises:
        requests.RequestException: if the query or a download fails.
    """
    # see <http://beets.readthedocs.io/en/v1.4.5/reference/query.html#query-term-negation>
    response = requests.get(BEETS_API_ROOT + 'query/^bpm:0', timeout=30)
    response.raise_for_status()
    beets_tracks_with_bpm = response.json()['results']
    # write_bytes does not create missing parent directories.
    TEST_WAVS_DIRECTORY.mkdir(parents=True, exist_ok=True)
    for track in beets_tracks_with_bpm:
        beets_track_url = BEETS_API_ROOT + str(track['id'])
        file_format = (track.get('format') or '').lower()
        if file_format not in ('mp3', 'wav'):
            logger.warning('Skipping <%s>: unsupported format %r.',
                           beets_track_url, track.get('format'))
            continue
        audio_file = download_beets_track_file(beets_track_url)
        if file_format == 'mp3':
            try:
                wav_file = convert_mp3_to_wav_file(audio_file)
            finally:
                audio_file.close()  # TemporaryFile gets deleted on close
            audio_file = wav_file  # switcheroo
        destination = TEST_WAVS_DIRECTORY / ('%s.wav' % track['id'])
        with audio_file:  # closing the TemporaryFile deletes it
            destination.write_bytes(audio_file.read())
def downnsample_wav(src, dst, inrate=44100, outrate=16000,
                    inchannels=2, outchannels=1):
    """Resample a PCM WAV file and optionally change its channel count.

    Args:
        src (str or file-like): input WAV, anything ``wave.open`` accepts.
        dst (str or file-like): output WAV, anything ``wave.open`` accepts.
        inrate (int or None): sample rate of ``src`` in Hz; ``None`` reads
            it from the ``src`` header.
        outrate (int): sample rate written to ``dst`` in Hz.
        inchannels (int or None): channel count of ``src`` (1 or 2);
            ``None`` reads it from the ``src`` header.
        outchannels (int): channel count written to ``dst`` (1 or 2).
            Stereo to mono keeps only the left channel; mono to stereo
            duplicates the single channel.

    The output keeps the sample width of ``src``. NOTE(review): audioop
    treats samples as signed, while 8-bit WAV data is unsigned, so 8-bit
    input is not converted faithfully.

    Raises:
        ValueError: if a channel count is not 1 or 2.
        ImportError: on Python 3.13+, where ``audioop`` no longer exists.
    """
    # Imported locally (they were never imported before, which caused a
    # NameError). audioop is deprecated since 3.11 and removed in 3.13, so a
    # module-level import would break the whole script on newer interpreters.
    import audioop
    import wave

    if outchannels not in (1, 2):
        raise ValueError('outchannels must be 1 or 2, got %r' % (outchannels,))

    with wave.open(src, 'rb') as s_read:
        if inrate is None:
            inrate = s_read.getframerate()
        if inchannels is None:
            inchannels = s_read.getnchannels()
        sampwidth = s_read.getsampwidth()
        data = s_read.readframes(s_read.getnframes())

    if inchannels not in (1, 2):
        raise ValueError('inchannels must be 1 or 2, got %r' % (inchannels,))

    # ratecv returns (fragment, state); only the fragment is audio data.
    converted, _ = audioop.ratecv(data, sampwidth, inchannels, inrate,
                                  outrate, None)
    if inchannels == 2 and outchannels == 1:
        converted = audioop.tomono(converted, sampwidth, 1, 0)
    elif inchannels == 1 and outchannels == 2:
        converted = audioop.tostereo(converted, sampwidth, 1, 1)

    with wave.open(dst, 'wb') as s_write:
        s_write.setparams((outchannels, sampwidth, outrate, 0, 'NONE',
                           'Uncompressed'))
        s_write.writeframes(converted)
def _fetch_wav_file(beets_track_url, file_format):
    """Download a track and return it as a WAV TemporaryFile, or None if its
    format is neither 'mp3' nor 'wav' (nothing is downloaded in that case)."""
    if file_format == 'mp3':
        mp3_file = download_beets_track_file(beets_track_url)
        try:
            return convert_mp3_to_wav_file(mp3_file)
        finally:
            mp3_file.close()  # TemporaryFile gets deleted on close
    if file_format == 'wav':
        return download_beets_track_file(beets_track_url)
    return None


def main(beets_ids):
    """Download a batch of tracks from the server as WAV files.

    For each id, the track's metadata is fetched, and the track is skipped
    unless it has a non-zero bpm tag and a file format, a genre accepted by
    ``genre_is_of_interest``, and a bpm accepted by ``bpm_is_in_range`` for
    [LOWER_BOUND, UPPER_BOUND]. MP3 tracks are converted to WAV, WAV tracks
    are kept as-is, and other formats are skipped. Results are written to
    TEST_WAVS_DIRECTORY as ``<genre>_<id>_<bpm>.wav``. Ids whose metadata
    or audio cannot be fetched are logged and skipped instead of aborting
    the batch.

    Args:
        beets_ids (iterable of int or str): ids of the tracks on the server.
    """
    # write_bytes does not create missing parent directories.
    TEST_WAVS_DIRECTORY.mkdir(parents=True, exist_ok=True)
    for counter, beets_id in enumerate(beets_ids, start=1):
        beets_track_url = BEETS_API_ROOT + str(beets_id)
        try:
            bpm, file_format, genre = get_beets_track_bpm_and_format_tags(beets_track_url)
        except (requests.RequestException, ValueError) as err:
            logger.warning('Could not fetch metadata for <%s>: %s',
                           beets_track_url, err)
            continue
        if not bpm or not file_format:
            logger.warning('No BPM/file format for <%s> (bpm=%s, format=%s).',
                           beets_track_url, bpm, file_format)
            continue
        if not genre_is_of_interest(genre):
            continue
        if not bpm_is_in_range(bpm, LOWER_BOUND, UPPER_BOUND):
            continue
        logger.debug('Beets track <%s> has bpm=%s, format=%s.',
                     beets_track_url, bpm, file_format)
        try:
            wav_file = _fetch_wav_file(beets_track_url, file_format)
        except requests.RequestException as err:
            logger.warning('Could not download <%s>: %s', beets_track_url, err)
            continue
        if wav_file is None:
            # Previously this fell through with `wav_file` unbound or stale.
            logger.warning('Unsupported format %s for <%s>; skipping.',
                           file_format, beets_track_url)
            continue
        # A genre like "house/techno" would otherwise be read as a subdirectory.
        safe_genre = genre.replace('/', '-').replace('\\', '-')
        destination = TEST_WAVS_DIRECTORY / ('%s_%s_%s.wav' % (safe_genre, beets_id, bpm))
        # TODO: pass file handle directly to next program
        with wav_file:  # closing the TemporaryFile deletes it
            destination.write_bytes(wav_file.read())
        print(counter, beets_id)
if __name__ == '__main__':
    # Run one batch over beets ids 1 through 1000 inclusive.
    main([*range(1, 1001)])