لود تصویر با Opencv و ذخیره در TFRecord تنسورفلو - هفت خط کد انجمن پرسش و پاسخ برنامه نویسی

لود تصویر با Opencv و ذخیره در TFRecord تنسورفلو

0 امتیاز
سلام.

آیا امکانش وجود داره که به جای لود تصویر با تنسورفلو با Opencv تصویر را لود کنیم و همان تصویر را در TFRecord ذخیره و بازیابی کنیم؟
سوال شده آذر 17, 1396  بوسیله ی ثریا (امتیاز 126)   6 24 30

1 پاسخ

+1 امتیاز
 
بهترین پاسخ
def loadImage(file_name):
    image = cv2.imread(file_name,1)
    image = cv2.resize(image,(224,224))
    return image.tobytes()


def int64Feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytesFeature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def write(tf_file_name,path):
    writer = tf.python_io.TFRecordWriter(tf_file_name)
    query = path+r"\*\*.jpg"

    files_name=  glob.glob(query)
    label = 1

    for i in range(len(files_name)):
        image = loadImage(files_name[i])
        example = tf.train.Example(features=tf.train.Features(feature={
        'label': int64Feature(label),
        'image': bytesFeature(image)
         }))
        writer.write(example.SerializeToString())

    writer.close()

def loadFromSer(value):
    features = tf.parse_single_example(value,features={
        'image': tf.FixedLenFeature([], tf.string),
        'label':tf.FixedLenFeature([],tf.int64)

    })

    image = tf.decode_raw(features['image'],tf.uint8)
    label = tf.cast(features['label'], tf.int32)

    image = tf.reshape(image,[224,224,3])

    return  image,label


def read(tf_file_name):
    filenames = [tf_file_name]
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(loadFromSer)
    dataset = dataset.batch(5)
    dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.repeat(2)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    sess  = tf.Session()
    
    for _ in range(10):
        sess.run(iterator.initializer)
        while True:
            try:
                images,labels = sess.run(next_element)
                
                cv2.imshow("view",images[0])
                cv2.waitKey(0)

            except tf.errors.OutOfRangeError:
                break



def main():

    tf_file_name = r"train.tfrecord"
    folder = r"image_path"

    write(tf_file_name,folder)
    read(tf_file_name)

if __name__ == "__main__":


    main()

 

پاسخ داده شده آذر 18, 1396 بوسیله ی مصطفی ساتکی (امتیاز 21,998)   24 34 75
انتخاب شد آذر 18, 1396 بوسیله ی ثریا
...