شبکه عمیق برای کلاسفیکیشن MNIST - هفت خط کد انجمن پرسش و پاسخ برنامه نویسی

شبکه عمیق برای کلاسفیکیشن MNIST

+1 امتیاز
سلام.پیشنهاد شما برای کلاسیفکیشن MNIST که خصوصیات زیر را داشته باشه

۱- ساده جهت پیاده سازی

۲- پارامترهای کمی داشته باشه یا همان بارمحاسبای کم

۳-دقت خوبی داشته باشه

۴-با کراس پیاده سازی شه
سوال شده تیر 7, 1398  بوسیله ی ابید (امتیاز 781)   19 90 106

1 پاسخ

0 امتیاز

البته بنچ مارکی برای mnist وجود داره که دقت ها هر روش به همراه مقاله در آن ذکر شده ولی من با این روش دم دستی دقتی 99.4 گرفتم رو داده های تست.

import numpy as np
from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout
from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.utils import to_categorical, plot_modelfrom keras.datasets import mnist


(x_train, y_train), (x_test, y_test) = mnist.load_data()


num_labels = len(np.unique(y_train))


y_train = to_categorical(y_train)
y_test = to_categorical(y_test)


image_size = x_train.shape[1]


x_train = np.reshape(x_train,[-1, image_size, image_size, 1])
x_test = np.reshape(x_test,[-1, image_size, image_size, 1])

x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# network parameters
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
pool_size = 2
filters = 64
dropout = 0.2


model = Sequential()

model.add(Conv2D(filters=filters,
kernel_size=kernel_size,
activation='relu',
input_shape=input_shape))
model.add(MaxPooling2D(pool_size))

model.add(Conv2D(filters=filters,
kernel_size=kernel_size,
activation='relu'))
model.add(MaxPooling2D(pool_size))

model.add(Conv2D(filters=filters,
kernel_size=kernel_size,
activation='relu'))

model.add(Flatten())

model.add(Dense(num_labels))
model.add(Activation('softmax'))

model.summary()

plot_model(model, to_file='cnn-mnist.png', show_shapes=True)

model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])


model.fit(x_train, y_train, epochs=10, batch_size=batch_size)
loss, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print("\nTest accuracy: %.1f%%" % (100.0 * acc))

 

پاسخ داده شده تیر 9, 1398 بوسیله ی farnoosh (امتیاز 8,362)   20 44 59
...