In this implementation, I convert image and label to TFrecord file, and then read the file using tf.data
- TFRecord is a Tensorflow's standard file format, which stores data in binary format
- It is recommended to use for machine learning projects, especially one that involves with big dataset
- Since data are stored in binary format, it is faster and more flexible
- For a very large dataset whose size is too big to fit into your memory, it is easy to read partial data from the file
- There are two mainstep to implement TFRecord: convert data to TFRecord, and read TFRecord data
To convert data to TFRecord file format, there are five main steps as below:
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
img = load_image(os.path.join(img_directory,img_addr), img_width, img_height)
img, shape = get_image_binary(img)
if img_addr.find('cat') != -1:
label = 1
else:
label = 0
feature = {'label': _int64_feature(label),
'image': _bytes_feature(img),
'shape':_bytes_feature(shape)}
features = tf.train.Features(feature=feature)
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
Reading data using tf.data is much easier than traditional method. You can follow the instructions below:
def _parse(serialized_data):
features = {'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string),
'shape': tf.FixedLenFeature([], tf.string)}
features = tf.parse_single_example(serialized_data,
features)
img = tf.decode_raw(features['image'],tf.uint8)
shape = tf.decode_raw(features['shape'],tf.int32)
img = tf.reshape(img, shape)
return img, features['label']
dataset = tf.data.TFRecordDataset(tfrecord_filename)
dataset = dataset.map(_parse)
iterator = dataset.make_initializable_iterator()
img, label = iterator.get_next()
To test the code, just run tf_record.py and you will see the tfrecord file in the working directory