本文共 2884 字，大约阅读时间需要 9 分钟。
莫烦 PyTorch 学习记录——CNN
# cnn.py
# MNIST digit classifier: two Conv2d/ReLU/MaxPool2d stages followed by a
# linear layer, trained with Adam and cross-entropy loss
# (study notes on Morvan's PyTorch tutorial).
import os
import pdb
import time

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable  # no longer needed: plain tensors carry autograd state since PyTorch 0.4
from torch.utils import data
from torch.utils.data import DataLoader

# torchvision and matplotlib are imported inside the functions that use them,
# so CNN / predict / evaluate_accuracy can be imported without those packages.

EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False  # set True on the first run if MNIST is not yet under DATA_ROOT
DATA_ROOT = '../catVsDog/data'
N_TEST = 2000           # number of test images used for the periodic accuracy check
LOG_EVERY = 50          # evaluate and log every LOG_EVERY training steps
SHOW_SAMPLES = False    # plot the first train and test image before training


class CNN(nn.Module):
    """Convolutional classifier mapping (batch, 1, 28, 28) images to 10 class scores."""

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(                    # input (1, 28, 28)
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,                # (5 - 1) / 2 keeps the 28x28 spatial size
            ),                            # -> (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),   # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),              # -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return raw class scores of shape (batch, 10); CrossEntropyLoss consumes them directly."""
        x = self.conv1(x)
        x = self.conv2(x)                 # (batch, 32, 7, 7)
        x = x.view(x.size(0), -1)         # flatten to (batch, 32 * 7 * 7)
        return self.out(x)


def show_sample(images, labels, index=0):
    """Plot images[index] in grayscale with labels[index] as the title."""
    import matplotlib.pyplot as plt

    plt.imshow(images[index].numpy(), cmap='gray')
    plt.title('%i' % labels[index])
    plt.show()


def load_data(root=DATA_ROOT, download=DOWNLOAD_MNIST, batch_size=BATCH_SIZE, n_test=N_TEST):
    """Load MNIST from ``root``.

    Returns ``(train_data, train_loader, test_data, test_x, test_y)``:
    ``train_loader`` yields shuffled ``(images, labels)`` batches whose images
    ToTensor() has scaled from 0-255 to 0-1; ``test_x`` holds the first
    ``n_test`` test images as a float tensor with a channel dimension added
    and scaled to 0-1; ``test_y`` holds their labels.
    """
    import torchvision

    train_data = torchvision.datasets.MNIST(
        root=root,
        train=True,
        transform=torchvision.transforms.ToTensor(),  # 0-255 ==> 0-1
        download=download,
    )
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=2)

    test_data = torchvision.datasets.MNIST(root=root, train=False, download=download)
    # .data / .targets replace the deprecated .test_data / .test_labels aliases.
    # Slice first so only n_test images are converted to float.
    test_x = torch.unsqueeze(test_data.data[:n_test], dim=1).float() / 255.
    test_y = test_data.targets[:n_test]
    return train_data, train_loader, test_data, test_x, test_y


def predict(model, x):
    """Return the arg-max class index for each row of ``model(x)``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built,
    then restores the model's previous train/eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            return model(x).argmax(dim=1)
    finally:
        model.train(was_training)


def evaluate_accuracy(model, x, y):
    """Return the fraction (0.0-1.0) of samples in ``x`` whose prediction equals ``y``.

    Raises:
        ValueError: if ``y`` is empty (accuracy is undefined).
    """
    if y.numel() == 0:
        raise ValueError("cannot compute accuracy on an empty evaluation set")
    # Compare as tensors; no numpy round-trip is needed.
    return (predict(model, x) == y).float().mean().item()


def main():
    """Train CNN on MNIST, logging loss and test accuracy, then print 10 sample predictions."""
    train_data, train_loader, test_data, test_x, test_y = load_data(
        root=DATA_ROOT, download=DOWNLOAD_MNIST, batch_size=BATCH_SIZE, n_test=N_TEST)
    if SHOW_SAMPLES:
        show_sample(train_data.data, train_data.targets)
        show_sample(test_data.data, test_data.targets)

    cnn = CNN()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
    loss_func = nn.CrossEntropyLoss()

    for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):
            output = cnn(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % LOG_EVERY == 0:
                accuracy = evaluate_accuracy(cnn, test_x, test_y)
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '|test accuracy: %.2f' % accuracy)

    pred_y = predict(cnn, test_x[:10]).numpy()
    print(pred_y, 'prediction number')
    print(test_y[:10].numpy(), 'real number')


# The DataLoader uses worker processes (num_workers=2). On platforms that start
# them with "spawn" (Windows, macOS) each worker re-imports this module, so
# training must only start from the guarded entry point.
if __name__ == "__main__":
    main()
转载地址：http://mwksi.baihongyu.com/