تبدیل MNIST به TFRecord - هفت خط کد انجمن پرسش و پاسخ برنامه نویسی

تبدیل MNIST به TFRecord

0 امتیاز
سلام.

قصد دارم به صورت عملی تبدیل یک دیتاست تصویری را به TFRecord یاد بگیرم به نظرم رسید مثال ساده برای اینکار دیتاست MNIST باشه کسی هست که MNIST به TFRecord تبدیل کرده باشه؟
سوال شده آذر 6, 1396  بوسیله ی amin_sajadi (امتیاز 15)   2 2

1 پاسخ

0 امتیاز

سلام.تو مثالی که براتون قرار دادم از دیتاست MNIST است داده های هر 3 بخش train,test,validation را در فایل tfrecord ذخیره می کنه و سپس جهت نمونه فقط داده های بخش train را بارگذاری می کنه و نمایش میده. مسیر دیتاست MNIST و مسیری که قراره فایل های tfrecord را در آن ایجاد کنه را مشخص کنید.

import tensorflow as tf
import  numpy as np
from tensorflow.contrib.learn.python.learn.datasets import mnist
import  os
import cv2


def intt64Feature(value):
    return  tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytesFeature(value):
    return  tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def createMNIST2TFRecord(mnist_path,tf_path):
    data_splits = ["train","test","validation"]
    datasets = mnist.read_data_sets(mnist_path,dtype=tf.uint8,reshape=False,
                                    validation_size=1000)

    for d in range(len(data_splits)):
        cur_split = data_splits[d]
        print("saving "+cur_split)
        dataset = datasets[d]
        file_name = os.path.join(tf_path,cur_split+".tfrecord")
        print(file_name)
        writer = tf.python_io.TFRecordWriter(file_name)

        images_count = dataset.images.shape[0]
        for index in range(images_count) :
            cur_img =dataset.images[index]
            image = cur_img.tostring()
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'height':intt64Feature(dataset.images.shape[1]),
                    'width':intt64Feature(dataset.images.shape[2]),
                    'depth':intt64Feature(dataset.images.shape[3]),
                    'label':intt64Feature(dataset.labels[index]),
                    "raw_image":bytesFeature(image)
                }
            ))
            writer.write(example.SerializeToString())
        writer.close()


def loadMNISTFromTFRecord(file_name):

    data_iterator = tf.python_io.tf_record_iterator(file_name)

    while True:
        try:
            example_serialized = next(data_iterator)
            example = tf.train.Example()
            example.ParseFromString(example_serialized)

            width = example.features.feature['width'].int64_list.value[0]
            height = example.features.feature['height'].int64_list.value[0]
            depth = example.features.feature['depth'].int64_list.value[0]
            label = example.features.feature['label'].int64_list.value[0]
            image = example.features.feature['raw_image'].bytes_list.value

            flat_image = np.fromstring(image[0],np.uint8)
            reshaped_img = flat_image.reshape((height,width,-1))
            cv2.imshow("view",reshaped_img)
            cv2.waitKey(0)
        except tf.errors.OutOfRangeError:
            break


def main():
    mnist_path = r"D:\Database\MNIST"
    ft_path = r"D:\tf_test"
    createMNIST2TFRecord(mnist_path,ft_path)
    ft_train_file_name = os.path.join(ft_path,"train.tfrecord")
    loadMNISTFromTFRecord(ft_train_file_name)

if __name__ == "__main__":
    main()


 

پاسخ داده شده آذر 7, 1396 بوسیله ی مصطفی ساتکی (امتیاز 21,998)   24 34 75
...