{"id":7846,"date":"2019-08-24T06:00:11","date_gmt":"2019-08-23T21:00:11","guid":{"rendered":"http:\/\/www.gisdeveloper.co.kr\/?p=7846"},"modified":"2020-05-28T10:16:53","modified_gmt":"2020-05-28T01:16:53","slug":"dnn%ec%9d%84-%ec%9d%b4%ec%9a%a9%ed%95%9c-fashion-mnist-%eb%8d%b0%ec%9d%b4%ed%84%b0%ec%97%90-%eb%8c%80%ed%95%9c-classifier","status":"publish","type":"post","link":"http:\/\/www.gisdeveloper.co.kr\/?p=7846","title":{"rendered":"DNN\uc744 \uc774\uc6a9\ud55c Fashion-MNIST \ub370\uc774\ud130\uc5d0 \ub300\ud55c Classifier"},"content":{"rendered":"<p>\ucc98\uc74c \ub525\ub7ec\ub2dd\uc744 \ud14c\uc2a4\ud2b8 \ud558\uae30 \uc704\ud574 \ud754\ud788 \uc0ac\uc6a9\ud558\ub294 \ub370\uc774\ud130\ub294 MNIST \uc785\ub2c8\ub2e4. 0~9\uae4c\uc9c0\uc758 \uc190\uae00\uc528\uc5d0 \ub300\ud55c 28&#215;28 \ud06c\uae30\uc758 \uc774\ubbf8\uc9c0\uc785\ub2c8\ub2e4. \uc774\ubbf8\uc9c0 \ub370\uc774\ud130\uc640 \ud568\uaed8 \ub77c\ubca8 \ub370\uc774\ud130\ub3c4 \uc81c\uacf5\ub418\ubbc0\ub85c \ubc14\ub85c \ud65c\uc6a9\ud560 \uc218 \uc788\ub294 \ub370\uc774\ud130\uc785\ub2c8\ub2e4. MNIST\uc5d0 \uc601\uac10\uc744 \ubc1b\uc740 <a href='https:\/\/github.com\/zalandoresearch\/fashion-mnist'>Fashion-MNIST<\/a>\ub77c\ub294 \ub370\uc774\ud130\uac00 \uc874\uc7ac\ud569\ub2c8\ub2e4. \ucd1d 10\uac00\uc9c0\uc758 \ud328\uc158 \uc544\uc774\ud15c\uc5d0 \ub300\ud55c \uc774\ubbf8\uc9c0\uc640 \ub77c\ubca8\uc785\ub2c8\ub2e4. 
\uc81c\uacf5\ub418\ub294 \uc774\ubbf8\uc9c0\uc758 \uc608\ub294 \uc544\ub798\uc640 \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<p><img loading=\"lazy\" decoding=\"async\" src=\"http:\/\/www.gisdeveloper.co.kr\/wp-content\/uploads\/2019\/08\/fashion-mnist-sprite.png\" alt=\"\" width=\"840\" height=\"840\" class=\"aligncenter size-full wp-image-7847\" \/><\/p>\n<p>\uc704 \uc774\ubbf8\uc9c0\ub4e4\uc5d0\uc11c \uc0c1\ub2e8\ubd80\ud130 3\uac1c\uc758 Rows\uc529\uc744 \ud55c \uadf8\ub8f9\uc73c\ub85c \ubb36\uc73c\uba74 \uac01\uac01\uc758 \uadf8\ub8f9\uc774 \uc758\ubbf8\ud558\ub294 \uac83\uc740 \uc544\ub798\uc758 \ud45c\uc640 \uadf8 \ud56d\ubaa9\uc758 \uc21c\uc11c\uac00 \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<p><center><\/p>\n<table style='width:160px'>\n<tr>\n<th style='width:40px'>Label<\/th>\n<th style='width:120px'>Description<\/th>\n<\/tr>\n<tr>\n<td>0<\/td>\n<td>T-shirt\/top<\/td>\n<\/tr>\n<tr>\n<td>1<\/td>\n<td>Trouser<\/td>\n<\/tr>\n<tr>\n<td>2<\/td>\n<td>Pullover<\/td>\n<\/tr>\n<tr>\n<td>3<\/td>\n<td>Dress<\/td>\n<\/tr>\n<tr>\n<td>4<\/td>\n<td>Coat<\/td>\n<\/tr>\n<tr>\n<td>5<\/td>\n<td>Sandal<\/td>\n<\/tr>\n<tr>\n<td>6<\/td>\n<td>Shirt<\/td>\n<\/tr>\n<tr>\n<td>7<\/td>\n<td>Sneaker<\/td>\n<\/tr>\n<tr>\n<td>8<\/td>\n<td>Bag<\/td>\n<\/tr>\n<tr>\n<td>9<\/td>\n<td>Ankle boot<\/td>\n<\/tr>\n<\/table>\n<p><\/center><\/p>\n<p>\ud6c8\ub828\uc744 \uc704\ud55c \uc774\ubbf8\uc9c0\uc640 \ub77c\ubca8\uc758 \uc218\ub294 \uac01\uac01 60,000\uac1c, \uc2dc\ud5d8\uc744 \uc704\ud55c \uc774\ubbf8\uc9c0\uc640 \ub77c\ubca8\uc758 \uc218\ub294 \uac01\uac01 10,000\uac1c\uc785\ub2c8\ub2e4. 
\uc774 \uae00\uc740 Fashion-MNIST\ub97c PyTorch\ub97c \uc774\uc6a9\ud574 \ud6c8\ub828\uc744 \uc2dc\ucf1c\ubcf4\ub294 \ucf54\ub4dc\uc640 \uadf8 \uacb0\uacfc\uc5d0 \ub300\ud55c \uc124\uba85\uc785\ub2c8\ub2e4.<\/p>\n<p>\uba3c\uc800 \ud544\uc694\ud55c \ub77c\uc774\ube0c\ub7ec\ub9ac\ub97c import \ud569\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\nimport torch\r\nimport torch.nn as nn\r\nimport torchvision.datasets as dset\r\nimport torchvision.transforms as transforms\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom torch.utils.data import DataLoader\r\n<\/pre>\n<p>\ud558\uc774\ud37c \ud30c\ub77c\uba54\ud130\ub294 \ub2e4\uc74c\uacfc \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\nbatch_size = 100\r\nnum_epochs = 250\r\nlearning_rate = 0.0001\r\n<\/pre>\n<p>\ud6c8\ub828\uc5d0 \ud544\uc694\ud55c \ub370\uc774\ud130\uc640 \uc2dc\ud5d8\uc5d0 \ud544\uc694\ud55c \ub370\uc774\ud130\ub97c \ub2e4\uc6b4\ub85c\ub4dc \ubc1b\uc544\uc57c \ud558\ub294\ub370, PyTorch\uc5d0\uc11c\ub294 \uc774\ub97c \uc704\ud55c \ub3c4\uad6c\ub97c \uc9c0\uc6d0\ud558\ubbc0\ub85c, \uc774\ub97c \ud65c\uc6a9\ud558\uc5ec \ub2e4\uc74c \ucf54\ub4dc\ucc98\ub7fc MNIST_Fashion \ud3f4\ub354\uc5d0 \ub370\uc774\ud130\ub97c \ub2e4\uc6b4\ubc1b\uace0 \ub370\uc774\ud130\ub97c \ud65c\uc6a9\ud560 \uc900\ube44\ub97c \ud569\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\nroot = '.\/MNIST_Fashion'\r\ntransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])\r\ntrain_data = dset.FashionMNIST(root=root, train=True, transform=transform, download=True)\r\ntest_data = dset.FashionMNIST(root=root, train=False, transform=transform, download=True)\r\n\r\ntrain_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)\r\ntest_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
shuffle=False, drop_last=True)\r\n<\/pre>\n<p>\uc774\uc81c \ubaa8\ub378\uc744 \uc815\uc758(\uc544\ub798 \ucf54\ub4dc\uc758 4\ubc88 \ucf54\ub4dc\uc778 DNN \ud074\ub798\uc2a4\uc758 __init__ \ud568\uc218)\ud558\uace0 \uc0ac\uc6a9\ud560 \uc190\uc2e4\ud568\uc218(\uc544\ub798\uc758 37\ubc88 \ucf54\ub4dc)\uc640 \ucd5c\uc18c\uc758 \uc190\uc2e4\uac12\uc744 \uac00\uc9c0\ub294 \uac00\uc911\uce58\uc640 \ud3b8\ud5a5\uac12\uc744 \ucc3e\uae30 \uc704\ud574 \ubc29\ubc95(\uc544\ub798\uc758 38\ubc88 \ucf54\ub4dc) \ubc0f \ub9e8 \ucc98\uc74c \uac00\uc911\uce58(\uc544\ub798\uc758 30\ubc88 weights_init \ucf54\ub4dc\uc640 35\ubc88 \ucf54\ub4dc)\ub97c \ucd08\uae30\ud654\ud569\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\r\n\r\nclass DNN(nn.Module):\r\n    def __init__(self):\r\n        super(DNN, self).__init__()\r\n\r\n        self.layer1 = nn.Sequential(\r\n            torch.nn.Linear(784, 256, bias=True),\r\n            torch.nn.BatchNorm1d(256),\r\n            torch.nn.ReLU()\r\n        )\r\n\r\n        self.layer2 = nn.Sequential(\r\n            torch.nn.Linear(256, 64, bias=True),\r\n            torch.nn.BatchNorm1d(64),\r\n            torch.nn.ReLU()\r\n        )\r\n\r\n        self.layer3 = nn.Sequential(\r\n            torch.nn.Linear(64, 10, bias=True)\r\n        )\r\n    \r\n    def forward(self, x):\r\n        x = x.view(x.size(0), -1) # flatten\r\n        x_out = self.layer1(x)\r\n        x_out = self.layer2(x_out)\r\n        x_out = self.layer3(x_out)\r\n        return x_out\r\n\r\ndef weights_init(m):\r\n    if isinstance(m, nn.Linear):\r\n        nn.init.xavier_normal_(m.weight) \r\n\r\nmodel = DNN().to(device)\r\nmodel.apply(weights_init)\r\n\r\ncriterion = torch.nn.CrossEntropyLoss().to(device)\r\noptimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\r\n<\/pre>\n<p>\uc704\uc5d0\uc11c \uc0ac\uc6a9\ud55c \ubaa8\ub378\uc740 DNN(Deep Neural Network)\uc73c\ub85c\uc368 
\ub2e4\uc74c\uacfc \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<p><img decoding=\"async\" src=\"http:\/\/www.gisdeveloper.co.kr\/wp-content\/uploads\/2019\/08\/MNSIT_Fashion_Model.png\" alt=\"\" width=\"80%\" class=\"aligncenter size-full wp-image-7867\" \/><\/p>\n<p>\ub2e4\uc74c\uc740 \ud6c8\ub828(Train)\uc5d0 \ub300\ud55c \ucf54\ub4dc\uc785\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\ncosts = []\r\ntotal_batch = len(train_loader)\r\nfor epoch in range(num_epochs):\r\n    total_cost = 0\r\n\r\n    for i, (imgs, labels) in enumerate(train_loader):\r\n        imgs, labels = imgs.to(device), labels.to(device)\r\n\r\n        outputs = model(imgs)\r\n        loss = criterion(outputs, labels)\r\n        \r\n        optimizer.zero_grad()\r\n        loss.backward()\r\n        optimizer.step()\r\n        \r\n        total_cost += loss\r\n\r\n    avg_cost = total_cost \/ total_batch\r\n    print(\"Epoch:\", \"%03d\" % (epoch+1), \"Cost =\", \"{:.9f}\".format(avg_cost))  \r\n    costs.append(avg_cost)              \r\n<\/pre>\n<p>\uc62c\ubc14\ub974\uac8c \ud559\uc2b5\uc774 \ub41c\ub2e4\uba74 \uc5d0\ud3ed\uc774 \uc99d\uac00\ub420\ub54c\ub9c8\ub2e4 \uc190\uc2e4\uac12\uc740 \uc810\uc810 \uc904\uc5b4 \ub4e4\uac8c \ub418\ub294\ub370, X\ucd95\uc774 \uc5d0\ud3ed, Y\ucd95\uc774 \uc190\uc2e4(\ube44\uc6a9)\uc778 \uc544\ub798\uc758 \uadf8\ub798\ud504\uc758 \uacb0\uacfc\ub97c \ub3c4\ucd9c\ud588\uc2b5\ub2c8\ub2e4.<\/p>\n<p><img loading=\"lazy\" decoding=\"async\" src=\"http:\/\/www.gisdeveloper.co.kr\/wp-content\/uploads\/2019\/08\/Figure_1.png\" alt=\"\" width=\"1337\" height=\"734\" class=\"aligncenter size-full wp-image-7869\" \/><\/p>\n<p>\uc774\uc81c \ud559\uc2b5\ub41c \uac00\uc911\uce58\uc640 \ud3b8\ud5a5\uac12\uc744 \ud1b5\ud574 \ud14c\uc2a4\ud2b8\ub97c \ud558\ub294 \ucf54\ub4dc\ub294 \uc544\ub798\uc640 \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\nmodel.eval()\r\nwith torch.no_grad():\r\n    
correct = 0\r\n    total = 0\r\n    for i, (imgs, labels) in enumerate(test_loader):\r\n        imgs, labels = imgs.to(device), labels.to(device)\r\n        outputs = model(imgs)\r\n        _, argmax = torch.max(outputs, 1)\r\n        total += imgs.size(0)\r\n        correct += (labels == argmax).sum().item()\r\n    \r\n    print('Accuracy for {} images: {:.2f}%'.format(total, correct \/ total * 100))                \r\n<\/pre>\n<p>\uc815\ud655\ub3c4\ub294 89.39%\uac00 \ub3c4\ucd9c\ub418\uc5c8\ub294\ub370\uc694. \uc774 \uc815\ud655\ub3c4\ub97c \uc810\ub354 \uc2dc\uac01\uc801\uc73c\ub85c \uc778\uc9c0\ud558\uae30 \uc704\ud574 \ud14c\uc2a4\ud2b8 \ub370\uc774\ud130\uc5d0\uc11c 36\uac1c\uc758 \uc774\ubbf8\uc9c0\ub97c \ubf51\uc544 \uac01\uac01\uc758 \uc774\ubbf8\uc9c0\uac00 \uc5b4\ub5a4 \uac83\uc73c\ub85c \ubd84\ub958\ub418\uc5c8\ub294\uc9c0 \ud655\uc778\ud558\ub294 \ucf54\ub4dc\ub97c \uc791\uc131\ud574 \ubcf4\uba74 \ub2e4\uc74c\uacfc \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<pre class=\"EnlighterJSRAW\" data-enlighter-language=\"python\">\r\nlabel_tags = {\r\n    0: 'T-Shirt', \r\n    1: 'Trouser', \r\n    2: 'Pullover', \r\n    3: 'Dress', \r\n    4: 'Coat', \r\n    5: 'Sandal', \r\n    6: 'Shirt',\r\n    7: 'Sneaker', \r\n    8: 'Bag', \r\n    9: 'Ankle Boot'\r\n}\r\n\r\ncolumns = 6\r\nrows = 6\r\nfig = plt.figure(figsize=(10,10))\r\n \r\nmodel.eval()\r\nfor i in range(1, columns*rows+1):\r\n    data_idx = np.random.randint(len(test_data))\r\n    input_img = test_data[data_idx][0].unsqueeze(dim=0).to(device) \r\n \r\n    output = model(input_img)\r\n    _, argmax = torch.max(output, 1)\r\n    pred = label_tags[argmax.item()]\r\n    label = label_tags[test_data[data_idx][1]]\r\n    \r\n    fig.add_subplot(rows, columns, i)\r\n    if pred == label:\r\n        plt.title(pred + ', right !!')\r\n        cmap = 'Blues'\r\n    else:\r\n        plt.title('Not ' + pred + ' but ' +  label)\r\n        cmap = 'Reds'\r\n    plot_img = test_data[data_idx][0][0,:,:]\r\n    
plt.imshow(plot_img, cmap=cmap)\r\n    plt.axis('off')\r\n    \r\nplt.show() \r\n<\/pre>\n<p>\uc704\uc758 10\ubc88 \ucf54\ub4dc\uc5d0\uc11c unsqueeze() \ud568\uc218\ub97c \uc0ac\uc6a9\ud55c \uac83\uc740 \uc6d0\ubcf8 \ub370\uc774\ud130\uc758 Shape\uac00 (1, 28, 28)\uc778\ub370, \uc774\ub97c \ubaa8\ub378\uc5d0 \uc785\ub825\ub418\ub294 \ub370\uc774\ud130\uc758 Shape\uc778 (1, 1, 28, 28)\ub85c \ubcc0\ud658\ud574\uc57c \ud558\uae30 \ub54c\ubb38\uc785\ub2c8\ub2e4. \uacb0\uacfc\ub294 \ub2e4\uc74c\uacfc \uac19\uc2b5\ub2c8\ub2e4.<\/p>\n<p><img loading=\"lazy\" decoding=\"async\" src=\"http:\/\/www.gisdeveloper.co.kr\/wp-content\/uploads\/2019\/08\/dnn_result.png\" alt=\"\" width=\"2058\" height=\"1426\" class=\"aligncenter size-full wp-image-8059\" \/><\/p>\n<p>\uc704\uc758 \uc774\ubbf8\uc9c0\ub97c \ubcf4\uba74 \ubd84\ub958\uac00 3\uac1c \uc815\ub3c4 \ud2c0\ub9b0 \uac83\uc73c\ub85c \ud45c\uc2dc\ub429\ub2c8\ub2e4. \uc8fc\ub85c T\uc154\uce20\ub97c \uadf8\ub0e5 \uc154\uce20\ub85c \ubd84\ub958\ud558\uac70\ub098 \uc2a4\uc6e8\ud130\ub97c \uc154\uce20\ub85c \ubd84\ub958\ud55c \uacbd\uc6b0\uc785\ub2c8\ub2e4. \uc774 \uacbd\uc6b0 CNN\uc73c\ub85c \ud559\uc2b5\ud558\uba74 \uc815\ud655\ub3c4\ub97c \ub354\uc6b1 \ud5a5\uc0c1\uc2dc\ud0ac \uc218 \uc788\uc2b5\ub2c8\ub2e4.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>\ucc98\uc74c \ub525\ub7ec\ub2dd\uc744 \ud14c\uc2a4\ud2b8 \ud558\uae30 \uc704\ud574 \ud754\ud788 \uc0ac\uc6a9\ud558\ub294 \ub370\uc774\ud130\ub294 MNIST \uc785\ub2c8\ub2e4. 0~9\uae4c\uc9c0\uc758 \uc190\uae00\uc528\uc5d0 \ub300\ud55c 28&#215;28 \ud06c\uae30\uc758 \uc774\ubbf8\uc9c0\uc785\ub2c8\ub2e4. \uc774\ubbf8\uc9c0 \ub370\uc774\ud130\uc640 \ud568\uaed8 \ub77c\ubca8 \ub370\uc774\ud130\ub3c4 \uc81c\uacf5\ub418\ubbc0\ub85c \ubc14\ub85c \ud65c\uc6a9\ud560 \uc218 \uc788\ub294 \ub370\uc774\ud130\uc785\ub2c8\ub2e4. MNIST\uc5d0 \uc601\uac10\uc744 \ubc1b\uc740 Fashion-MNIST\ub77c\ub294 \ub370\uc774\ud130\uac00 \uc874\uc7ac\ud569\ub2c8\ub2e4. 
\ucd1d 10\uac00\uc9c0\uc758 \ud328\uc158 \uc544\uc774\ud15c\uc5d0 \ub300\ud55c \uc774\ubbf8\uc9c0\uc640 \ub77c\ubca8\uc785\ub2c8\ub2e4. \uc81c\uacf5\ub418\ub294 \uc774\ubbf8\uc9c0\uc758 \uc608\ub294 \uc544\ub798\uc640 \uac19\uc2b5\ub2c8\ub2e4. \uc704 \uc774\ubbf8\uc9c0\ub4e4\uc5d0\uc11c \uc0c1\ub2e8\ubd80\ud130 3\uac1c\uc758 Rows\uc529\uc744 \ud55c \uadf8\ub8f9\uc73c\ub85c \ubb36\uc73c\uba74 \uac01\uac01\uc758 \uadf8\ub8f9\uc774 &hellip; <\/p>\n<p class=\"link-more\"><a href=\"http:\/\/www.gisdeveloper.co.kr\/?p=7846\" class=\"more-link\">\ub354 \ubcf4\uae30<span class=\"screen-reader-text\"> &#8220;DNN\uc744 \uc774\uc6a9\ud55c Fashion-MNIST \ub370\uc774\ud130\uc5d0 \ub300\ud55c Classifier&#8221;<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[132],"tags":[],"class_list":["post-7846","post","type-post","status-publish","format-standard","hentry","category-deep-machine-learning"],"_links":{"self":[{"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/posts\/7846","targetHints":{"allow":["GET"]}}],"collection":[{"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=7846"}],"version-history":[{"count":35,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/posts\/7846\/revisions"}],"predecessor-version":[{"id":9384,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=\/wp\/v2\/posts\/7846\/revisions\/9384"}],"wp:attachment":[{"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=7846"}],"w
p:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=7846"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/www.gisdeveloper.co.kr\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=7846"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}