19 - Convolutional Layers

For the written introduction to this lecture, please refer to Mu Li's online book: https://zh-v2.d2l.ai/chapter_convolutional-neural-networks/why-conv.html

Code

import torch
from torch import nn

def corr2d(X, K):    # X is the input, K is the kernel matrix
    h, w = K.shape   # number of rows and columns of K
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))  # initialize the output Y with zeros
    for i in range(Y.shape[0]):   # cross-correlation ("convolution") over every output position
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i+h, j:j+w] * K).sum()
    return Y
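For reference, each element computed by the double loop above is

$$Y[i,\,j] = \sum_{a=0}^{h-1}\sum_{b=0}^{w-1} X[i+a,\,j+b]\,K[a,\,b]$$

so the output has shape (X.shape[0] - h + 1, X.shape[1] - w + 1), which is exactly the size used in the torch.zeros call.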
# test on a small example
X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
K = torch.tensor([[0, 1], [2, 3]])
corr2d(X, K)
>>> tensor([[19., 25.],
            [37., 43.]])
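As a sanity check (an addition to these notes, not from the lecture), corr2d can be compared with PyTorch's built-in routine. torch.nn.functional.conv2d expects 4-D float tensors shaped (batch, channel, height, width), and like nn.Conv2d it actually computes cross-correlation, so the results should match:

import torch.nn.functional as F

Xf = X.float().reshape(1, 1, 3, 3)   # add batch and channel dimensions: (N, C, H, W)
Kf = K.float().reshape(1, 1, 2, 2)   # kernel shape: (out_channels, in_channels, kH, kW)
torch.allclose(F.conv2d(Xf, Kf).reshape(2, 2), corr2d(X, K))  # expected: True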
# implement a 2-D convolutional layer
class Conv2d(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))  # randomly initialized kernel
        self.bias = nn.Parameter(torch.zeros(1))              # scalar bias
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias
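A quick usage sketch of the layer defined above (added here for illustration): construct it with a kernel shape and apply it to a 2-D input. The weight is random, so only the output shape is meaningful:

layer = Conv2d((2, 2))            # random 2x2 kernel plus a scalar bias
out = layer(torch.ones((6, 8)))
out.shape                         # expected: torch.Size([5, 7]), i.e. (6-2+1, 8-2+1)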
X = torch.ones((6, 8))
X[:, 2:6] = 0    # a 6x8 image: ones at the sides, a stripe of zeros in columns 2-5
X
>>> tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])
K = torch.tensor([[-1, 1]])  # this kernel can only detect vertical edges
Y=corr2d(X,K)
Y
>>> tensor([[ 0., -1.,  0.,  0.,  0.,  1.,  0.],
            [ 0., -1.,  0.,  0.,  0.,  1.,  0.],
            [ 0., -1.,  0.,  0.,  0.,  1.,  0.],
            [ 0., -1.,  0.,  0.,  0.,  1.,  0.],
            [ 0., -1.,  0.,  0.,  0.,  1.,  0.],
            [ 0., -1.,  0.,  0.,  0.,  1.,  0.]])
corr2d(X.t(), K)   # same kernel on the transposed image: the edges are now horizontal and go undetected
>>> tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]])
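The all-zero output shows that this kernel misses edges once they become horizontal. A small follow-up sketch (not in the original notes): transposing the kernel as well recovers them.

corr2d(X.t(), K.t())   # K.t() is the 2x1 kernel [[-1], [1]]
# Expected: a 7x6 result that is all -1 in row 1 and all 1 in row 5,
# i.e. the two horizontal edges of the transposed image are detected.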
# learn the edge-detection kernel from the (X, Y) pair
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)  # 1 input channel, 1 output channel, 1x2 kernel

X = X.reshape((1, 1, 6, 8))   # reshape to (batch, channel, height, width)
Y = Y.reshape((1, 1, 6, 7))

for i in range(10):
    Y_hat = conv2d(X)
    l = (Y_hat - Y)**2            # squared-error loss
    conv2d.zero_grad()
    l.sum().backward()
    conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad   # gradient descent step, learning rate 3e-2
    if (i + 1) % 2 == 0:
        print(f'batch {i+1}, loss {l.sum():.3f}')
>>> batch 2, loss 3.852
    batch 4, loss 1.126
    batch 6, loss 0.386
    batch 8, loss 0.145
    batch 10, loss 0.057
conv2d.weight.data.reshape((1, 2))   # learned kernel, close to the target [[-1, 1]]
>>> tensor([[-1.0173,  0.9685]])
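The manual update of conv2d.weight.data above is plain gradient descent; the same training loop can be written with a built-in optimizer. A sketch (an alternative formulation, not from the lecture) using torch.optim.SGD with the same learning rate:

conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
optimizer = torch.optim.SGD(conv2d.parameters(), lr=3e-2)   # plain SGD reproduces the manual update
for i in range(10):
    optimizer.zero_grad()
    l = ((conv2d(X) - Y) ** 2).sum()   # X, Y as reshaped above
    l.backward()
    optimizer.step()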