"""Make it easier to work with the CheXpert .csv files.
- Create explicit columns for patient ID, study nubmer, and view number, extracted from the file
paths.
- Adjust the column data types to reduce memory usage.
- Add a column for age groups to help cross-sectional analysis.
- Labels are encoded as integers, including the "no mention" (encoded as an empty string in the
validation set is converted for consistency.
Using from the command line: ``python3 -m preprocess > chexpert.csv``
From another module::
import chexpert_dataset as cd
cxdata = cd.CheXpertDataset()
cxdata.fix_dataset() # optional
cxdata.df.head()
IMPORTANT: because we are using categories, set observed=True when using groupby with the
categorical columns to avoid surprises (/~https://github.com/pandas-dev/pandas/issues/17594)
"""
# pylint: disable=too-few-public-methods
import logging
import os
import re
import pandas as pd
import imagesize
# Dataset invariants that must hold when we manipulate it (groupby, pivot_table, filters, etc.)
# Numbers come from analyzing the .csv files shipped with the dataset (see chexpert_csv_eda.py)
# If the `assert` statements start to fail in the code, either the code is broken or the dataset has changed
PATIENT_NUM_TRAINING = 64_540
PATIENT_NUM_VALIDATION = 200
PATIENT_NUM_TOTAL = PATIENT_NUM_VALIDATION + PATIENT_NUM_TRAINING
STUDY_NUM_TRAINING = 187_641
STUDY_NUM_VALIDATION = 200
STUDY_NUM_TOTAL = STUDY_NUM_TRAINING + STUDY_NUM_VALIDATION
IMAGE_NUM_TRAINING = 223_414
IMAGE_NUM_VALIDATION = 234
IMAGE_NUM_TOTAL = IMAGE_NUM_VALIDATION + IMAGE_NUM_TRAINING
# Number of unique combinations of "patient id/age group"
# This number is larger than the number of patients because some patients have studies over
# multiple years, crossing age groups - if we group by age group we need to take this into
# account when checking the consistency of the datasets we are working with
# See how it was calculated in chexpert_statistics.py
PATIENT_NUM_TOTAL_BY_AGE_GROUP = 66_366
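# A hedged sketch of how this figure could be recomputed from the augmented DataFrame ``df``
# (COL_PATIENT_ID and COL_AGE_GROUP are defined below; observed=True counts only observed
# category combinations):
#   df.groupby([COL_PATIENT_ID, COL_AGE_GROUP], observed=True).ngroups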
# Labels as used in the DataFrame
LABEL_POSITIVE = 1
LABEL_NEGATIVE = 0
LABEL_UNCERTAIN = -1
LABEL_NO_MENTION = -99
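# For reference, the raw .csv label values map to these codes roughly as follows (an assumption
# based on the conversion in CheXpertDataset below, not an official CheXpert definition):
#   1.0 -> LABEL_POSITIVE, 0.0 -> LABEL_NEGATIVE, -1.0 -> LABEL_UNCERTAIN,
#   empty/NaN ("no mention") -> LABEL_NO_MENTION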
# Observations (must match the names in the .csv files)
OBSERVATION_NO_FINDING = 'No Finding'
OBSERVATION_PATHOLOGY = sorted(['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
                                'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
                                'Pneumothorax', 'Pleural Effusion', 'Pleural Other'])
OBSERVATION_OTHER = [OBSERVATION_NO_FINDING, 'Fracture', 'Support Devices']
OBSERVATION_ALL = OBSERVATION_OTHER + OBSERVATION_PATHOLOGY
# Names of some commonly-used columns already in the dataset
COL_SEX = 'Sex'
COL_AGE = 'Age'
COL_FRONTAL_LATERAL = 'Frontal/Lateral'
COL_AP_PA = 'AP/PA'
# Names of the columns added with this code
COL_PATIENT_ID = 'Patient ID'
COL_STUDY_NUMBER = 'Study Number'
COL_VIEW_NUMBER = 'View Number'
COL_AGE_GROUP = 'Age Group'
COL_TRAIN_VALIDATION = 'Training/Validation'
# Values of columns added with this code
TRAINING = 'Training'
VALIDATION = 'Validation'
class CheXpertDataset:
    """An augmented version of the CheXpert dataset.

    Create one DataFrame that combines the train.csv and valid.csv files, then augment it. The
    combined DataFrame appends the following columns to the existing dataset columns:

    - Patient number
    - Study number
    - View number
    - Age group (MeSH age group - https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1794003/)
    - "Training" or "Validation" image

    It also converts the floating-point labels to integers (e.g. 0.0 to 0) and fills in empty
    ("no mention") label entries with LABEL_NO_MENTION.
    """

    def __init__(self, directory: str = None, add_image_size: bool = False, verbose: bool = True):
        """Populate the augmented dataset.

        Once the class is initialized, the augmented dataset is available as a Pandas DataFrame in
        the ``df`` class variable.

        Args:
            directory (str, optional): The directory where the dataset is saved, or ``None`` to
                search for the directory. Defaults to None.
            add_image_size (bool, optional): Add the image size (takes a few seconds). Defaults to
                False.
            verbose (bool, optional): Turn verbose logging on/off. Defaults to True (on).
        """
        self.__init_logger(verbose)
        self.__directory = directory
        self.__add_image_size = add_image_size
        self.__df = self.__get_augmented_chexpert()

    @property
    def df(self):
        """Return the DataFrame that contains the training and validation test sets.

        Make a copy before modifying it. This code does not return a copy to increase performance.
        """
        return self.__df

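    # A minimal usage sketch of the caveat above (hypothetical names, not part of the API):
    #   cxdata = CheXpertDataset()
    #   df_copy = cxdata.df.copy()  # modify the copy, not the DataFrame cached in the class
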
    def fix_dataset(self):
        """Fix issues with the dataset (in place).

        See code for what is fixed.
        """
        # There is one record with sex 'Unknown'. There is only one image for that patient, so we
        # don't have another record where the sex could be copied from. Change it to "Female"
        # (it doesn't matter much which sex we pick because it is one record out of 200,000+).
        self.df.loc[self.df.Sex == 'Unknown', ['Sex']] = 'Female'
        # remove_unused_categories() returns a new Series; assign it back to persist the change
        self.df['Sex'] = self.df['Sex'].cat.remove_unused_categories()

    @staticmethod
    def find_directory() -> str:
        """Determine the directory where the dataset is stored.

        There are two versions of the dataset, small and large. They are stored in
        CheXpert-v1.0-small and CheXpert-v1.0-large respectively. To make the code generic, this
        function finds out what version is installed.

        Note: assumes that 1) only one of the versions is installed and 2) that it is at the same
        level where this code is being executed.

        Returns:
            str: The name of the images directory or an empty string if it can't find one.
        """
        for entry in os.scandir('.'):
            if entry.is_dir() and re.match(r'CheXpert-v\d\.\d-', entry.name):
                return entry.name
        return ''

    def __init_logger(self, verbose: bool):
        """Init the logger.

        Args:
            verbose (bool): Turn verbose logging on/off.
        """
        self.__ch = logging.StreamHandler()
        self.__ch.setFormatter(logging.Formatter('%(message)s'))
        self.__logger = logging.getLogger(__name__)
        self.__logger.addHandler(self.__ch)
        self.__logger.setLevel(logging.INFO if verbose else logging.ERROR)

    def __get_augmented_chexpert(self) -> pd.DataFrame:
        """Get an augmented version of the CheXpert dataset.

        Add the columns described in the file header and compact the DataFrame to use less memory.

        Raises:
            RuntimeError: Cannot find the dataset directory and no directory was specified.

        Returns:
            pd.DataFrame: The dataset with the original and augmented columns.
        """
        directory = CheXpertDataset.find_directory() \
            if self.__directory is None else self.__directory
        if not directory:
            raise RuntimeError('Cannot find the CheXpert directory')

        self.__logger.info('Using the dataset in %s', directory)
        df = pd.concat(pd.read_csv(os.path.join(directory, f)) for f in ['train.csv', 'valid.csv'])

        # Convert the "no mention" label to an integer representation
        # IMPORTANT: assumes this is the only case of NaN after reading the .csv files
        self.__logger.info('Converting "no mention" to integer')
        df.fillna(LABEL_NO_MENTION, inplace=True)

        # Add the patient ID column by extracting it from the filename
        # Assume that the 'Path' column follows a well-defined format and extract from "patientNNN"
        self.__logger.info('Adding patient ID')
        df[COL_PATIENT_ID] = df.Path.apply(lambda x: int(x.split('/')[2][7:]))

        # Add the study number column, also assuming that the 'Path' column is well-defined
        self.__logger.info('Adding study number')
        df[COL_STUDY_NUMBER] = df.Path.apply(lambda x: int(x.split('/')[3][5:]))

        # Add the view number column, also assuming that the 'Path' column is well-defined
        self.__logger.info('Adding view number')
        view_regex = re.compile('/|_')
        df[COL_VIEW_NUMBER] = df.Path.apply(lambda x: int(re.split(view_regex, x)[4][4:]))

        # Add the MeSH age group column
        # Best reference I found for that: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1794003/
        # We have only complete years, so we can't use 'newborn'
        # Also prefix with zero because visualizers sort by ASCII code, not numeric value
        self.__logger.info('Adding age group')
        bins = [0, 2, 6, 13, 19, 45, 65, 80, 120]
        ages = ['(0-1) Infant', '(02-5) Preschool', '(06-12) Child', '(13-18) Adolescent',
                '(19-44) Adult', '(45-64) Middle age', '(65-79) Aged', '(80+) Aged 80']
        df[COL_AGE_GROUP] = pd.cut(df.Age, bins=bins, labels=ages, right=False)

        # Add the train/validation column
        self.__logger.info('Adding train/validation')
        df[COL_TRAIN_VALIDATION] = df.Path.apply(lambda x: TRAINING if 'train' in x else VALIDATION)

        # Add the image information columns
        if self.__add_image_size:
            self.__logger.info('Adding image size (takes a few seconds)')
            size = [imagesize.get(f) for f in df.Path]
            df[['Width', 'Height']] = pd.DataFrame(size, index=df.index)

        # Optimize memory usage: use categorical values and small integers when possible
        # https://pandas.pydata.org/pandas-docs/stable/user_guide/scale.html
        # IMPORTANT: because we are using categories, set observed=True when using groupby with
        # these columns to avoid surprises (/~https://github.com/pandas-dev/pandas/issues/17594)
        for c in [COL_SEX, COL_FRONTAL_LATERAL, COL_AP_PA, COL_AGE_GROUP, COL_TRAIN_VALIDATION]:
            df[c] = df[c].astype('category')
        for c in [COL_AGE, COL_PATIENT_ID, COL_STUDY_NUMBER, COL_VIEW_NUMBER]:
            df[c] = df[c].astype('int32')
        for c in OBSERVATION_ALL:
            df[c] = df[c].astype('int8')

        # A bare minimum amount of sanity checks
        assert df[df[COL_TRAIN_VALIDATION] ==
                  TRAINING][COL_PATIENT_ID].nunique() == PATIENT_NUM_TRAINING
        assert df[COL_PATIENT_ID].nunique() == PATIENT_NUM_TOTAL

        return df


def main():
    """Entrypoint to test this file from the command line."""
    chexpert = CheXpertDataset()
    chexpert.fix_dataset()
    print(chexpert.df.to_csv(index=False))


if __name__ == '__main__':
    main()
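
# A minimal sketch of reading the exported file back, assuming the output of ``main`` was
# redirected to chexpert.csv as shown in the module docstring. Note that the categorical and
# small-integer dtypes are not preserved by the CSV round trip and would need to be reapplied:
#
#   import pandas as pd
#   df = pd.read_csv('chexpert.csv')
#   df[COL_AGE_GROUP] = df[COL_AGE_GROUP].astype('category')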