def loadImage(file_name, size=(224, 224)):
    """Load an image as colour, resize it and return its raw pixel bytes.

    Args:
        file_name: Path of the image file to read.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(224, 224)``, which is the shape ``loadFromSer`` reshapes to —
            change both together.

    Returns:
        The resized uint8 pixel buffer (OpenCV BGR channel order, 3 channels)
        flattened with ``ndarray.tobytes()``.

    Raises:
        OSError: If OpenCV cannot read/decode *file_name* (``cv2.imread``
            returns ``None`` instead of raising).
    """
    image = cv2.imread(file_name, cv2.IMREAD_COLOR)
    if image is None:
        # imread signals failure by returning None; without this check the
        # error surfaces later as an opaque assertion inside cv2.resize.
        raise OSError("could not read image file: %r" % (file_name,))
    image = cv2.resize(image, size)
    return image.tobytes()
def int64Feature(value):
    """Wrap a single integer as a ``tf.train.Feature`` backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def bytesFeature(value):
    """Wrap a single bytes value as a ``tf.train.Feature`` backed by a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def write(tf_file_name, path, label=1):
    """Serialize every ``*.jpg`` one directory below *path* into a TFRecord file.

    Each image is loaded with ``loadImage`` and stored as a
    ``tf.train.Example`` with two features: ``'image'`` (raw pixel bytes) and
    ``'label'`` (int64).

    Args:
        tf_file_name: Destination TFRecord file path.
        path: Root folder; images are matched as ``path/*/*.jpg``.
        label: Integer label written for every image (defaults to 1, the
            value previously hard-coded).
    """
    import os  # local import: the file's import block is not visible here

    # os.path.join keeps the pattern portable; the old r"\*\*.jpg" suffix only
    # worked with Windows path separators. Sorting makes record order
    # deterministic, since glob's order depends on the filesystem.
    pattern = os.path.join(path, "*", "*.jpg")
    files_name = sorted(glob.glob(pattern))

    writer = tf.python_io.TFRecordWriter(tf_file_name)
    try:
        for file_name in files_name:
            image = loadImage(file_name)
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': int64Feature(label),
                'image': bytesFeature(image),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close (and flush) the writer even if an image fails to load.
        writer.close()
def loadFromSer(value):
    """Parse one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Intended as the ``Dataset.map`` function for records produced by
    ``write``. The raw ``'image'`` bytes are decoded as uint8 and reshaped to
    ``[224, 224, 3]``, so records must hold exactly that many pixels. The int64
    ``'label'`` is cast to int32.
    """
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(value, features=schema)
    pixels = tf.reshape(tf.decode_raw(parsed['image'], tf.uint8), [224, 224, 3])
    return pixels, tf.cast(parsed['label'], tf.int32)
def read(tf_file_name):
    """Preview images stored in a TFRecord file written by ``write``.

    Builds a pipeline: parse, shuffle (buffer 10), batch(5), repeat(2). It
    then re-initializes and drains that pipeline 10 times. For each batch it
    shows the first image in an OpenCV window named "view" and blocks until a
    key is pressed.

    Args:
        tf_file_name: Path of the TFRecord file to read.
    """
    dataset = tf.data.TFRecordDataset([tf_file_name])
    dataset = dataset.map(loadFromSer)
    # Shuffle individual examples *before* batching: shuffling after batch()
    # only reorders whole batches and never mixes examples between them.
    dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.batch(5)
    dataset = dataset.repeat(2)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # The context manager guarantees the session is closed.
    with tf.Session() as sess:
        for _ in range(10):
            sess.run(iterator.initializer)
            while True:
                try:
                    images, _ = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    break  # this pass over the dataset is exhausted
                cv2.imshow("view", images[0])
                cv2.waitKey(0)
    cv2.destroyAllWindows()
def main():
    """Write the images under ``image_path`` to ``train.tfrecord``, then preview them."""
    record_path = r"train.tfrecord"
    image_folder = r"image_path"
    write(tf_file_name=record_path, path=image_folder)
    read(tf_file_name=record_path)


if __name__ == "__main__":
    main()