خطا در inference مدل freeze شده در تنسورفلو - هفت خط کد انجمن پرسش و پاسخ برنامه نویسی

خطا در inference مدل freeze شده در تنسورفلو

0 امتیاز

سلام .

با استفاده از کد موجود در این پست مدل فریز شده را inference می کنم پایتون خطای زیر را میده به ظاهر چیزی غلط نیز به هر حال اجرا نمیشه.

graph =  'model name'
 
with tf.gfile.FastGFile(graph,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
 
 
with tf.Session() as sess:
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
    np.random.seed(234)
    inp = np.random.standard_normal([1, 224, 224, 3]).astype(np.float32)
    out = sess.run(sess.graph.get_tensor_by_name('dense_2/Softmax:0'),
                   feed_dict={'input_1:0': inp})
 
    print(out)

TypeError: unhashable type: 'numpy.ndarray' tensorflow

سوال شده مرداد 21, 1399  بوسیله ی همایون (امتیاز 220)   10 38 43

1 پاسخ

+1 امتیاز
 
بهترین پاسخ

من الان تست کردم به این صورت استفاده کنید کار می کنه :

  image_size = 224
  num_channels = 3
  images = []


  filename = image_path
  image = cv2.imread(filename)

  image = cv2.resize(image, (image_size, image_size), cv2.INTER_LINEAR)
  image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
  images.append(image)
  images = np.array(images, dtype=np.uint8)
  images = images.astype('float32')
  images = np.multiply(images, 1.0 / 255.0)

  x_batch = images.reshape(1, image_size, image_size, num_channels)

  graph = 'model name'
  with tf.gfile.GFile(frozen_graph, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

  with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def,
                        input_map=None,
                        return_elements=None,
                        name=""
                        )


  
  y_pred = graph.get_tensor_by_name("dense_2/Softmax:0")

  x = graph.get_tensor_by_name("input_1:0")
  y_test_images = np.zeros((1, 2))
  sess = tf.Session(graph=graph)
  feed_dict_testing = {x: x_batch}
  result = sess.run(y_pred, feed_dict=feed_dict_testing)

  max_index =  np.argmax(result, axis=1)

  print(max_index,result[0,max_index])

 

پاسخ داده شده مرداد 22, 1399 بوسیله ی عباس مولایی (امتیاز 2,754)   1 5 13
انتخاب شد شهریور 8, 1402 بوسیله ی farnoosh
...