博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch(6)
阅读量:4099 次
发布时间:2019-05-25

本文共 2884 字,大约阅读时间需要 9 分钟。

莫烦pytorch学习记录--cnn

#cnn.pyimport osimport timeimport numpy as npimport torchimport torchvisionimport torch.nn as nnfrom torch.utils.data import DataLoaderfrom torch.utils import dataimport torch.utils.data as Datafrom torch.autograd import Variableimport matplotlib.pyplot as pltimport pdbEPOCH=1BATCH_SIZE=50LR=0.001DOWNLOAD_MNIST=Falsetrain_data = torchvision.datasets.MNIST(    root='../catVsDog/data',    train=True,    transform=torchvision.transforms.ToTensor(),#0-255==>0-1    download=DOWNLOAD_MNIST    )# print(train_data.train_data.size())# print(train_data.train_labels.size())# plt.imshow(train_data.train_data[0].numpy(), cmap='gray')# plt.title('%i'%train_data.train_labels[0])# plt.show()train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)test_data = torchvision.datasets.MNIST(root='../catVsDog/data',train=False)# plt.imshow(test_data.test_data[0].numpy(), cmap='gray')# plt.title('%i'%test_data.test_labels[0])# plt.show()test_x=Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.test_y = test_data.test_labels[:2000]class CNN(nn.Module):    def __init__(self):        super(CNN,self).__init__()        self.conv1 = nn.Sequential(            nn.Conv2d(          #(1,28,28)                in_channels =1,                out_channels=16,                kernel_size=5,                stride=1,                padding=2,                ),           #(16,28,28)            nn.ReLU(),            nn.MaxPool2d(kernel_size=2),#(16,14,14)            )        self.conv2 = nn.Sequential(            nn.Conv2d(16, 32, 5, 1, 2),#(32,14,14)            nn.ReLU(),            nn.MaxPool2d(2)#(32,7,7)            )        self.out = nn.Linear(32*7*7,10)    def forward(self,x):        x=self.conv1(x)        x=self.conv2(x)          #(batch, 32,7,7)        x=x.view(x.size(0),-1)#展平(batch, 32*7*7)        output = self.out(x)        return outputcnn=CNN()# print(cnn)optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)loss_func = nn.CrossEntropyLoss()for epoch in range(EPOCH):    for step, (x,y) in enumerate(train_loader):        b_x = Variable(x)        b_y = Variable(y)        output =cnn(b_x)        loss=loss_func(output, b_y)        optimizer.zero_grad()        loss.backward()        optimizer.step()        if(step%50 ==0):            test_output = cnn(test_x)            pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()            #pdb.set_trace()            test_y2 = test_y.data.numpy()#pred_y和test_y都要转成numpy()不然识别率一直是0            accuracy = sum(pred_y == test_y2) / test_y.size(0)            print('Epoch: ',epoch, '| train loss: %.4f' % loss.item(), '|test accuracy: %.2f' %accuracy)test_output = cnn(test_x[:10])pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()print(pred_y, 'prediction number')print(test_y[:10].numpy(), 'real number')

 

转载地址:http://mwksi.baihongyu.com/

你可能感兴趣的文章
设计模式之策略模式
查看>>
深究Java中的RMI底层原理
查看>>
用idea创建一个maven web项目
查看>>
Kafka
查看>>
9.1 为我们的角色划分权限
查看>>
维吉尼亚之加解密及破解
查看>>
DES加解密
查看>>
TCP/IP协议三次握手与四次握手流程解析
查看>>
PHP 扩展开发 : 编写一个hello world !
查看>>
inet_ntoa、 inet_aton、inet_addr
查看>>
用模板写单链表
查看>>
用模板写单链表
查看>>
链表各类操作详解
查看>>
C++实现 简单 单链表
查看>>
数据结构之单链表——C++模板类实现
查看>>
Linux的SOCKET编程 简单演示
查看>>
正则匹配函数
查看>>
Linux并发服务器编程之多线程并发服务器
查看>>
聊聊gcc参数中的-I, -L和-l
查看>>
[C++基础]034_C++模板编程里的主版本模板类、全特化、偏特化(C++ Type Traits)
查看>>