pytorch---仿射变换
一、仿射变换
图片的旋转、平移、缩放等可以看做一个像素的重采样过程。将原图的像素映射到目标图像的对应位置上,可以
![\\begin\{bmatrix\} x\\\\ y\\\\ 1 \\end\{bmatrix\} = \\begin\{bmatrix\} \{x\}^\{s\} & \{y\}^\{s\} & 1 \\end\{bmatrix\} \* \\begin\{bmatrix\} a & b &0 \\\\ c & d & 0\\\\ e &f & 1 \\end\{bmatrix\}][begin_bmatrix_ x_ y_ 1 _end_bmatrix_ _ _begin_bmatrix_ _x_s_ _ _y_s_ _ 1 _end_bmatrix_ _ _begin_bmatrix_ a _ b _0 _ c _ d _ 0_ e _f _ 1 _end_bmatrix]
其中为原图的坐标,x,y为目标图的坐标,该变换称为前向变换,遍历原图像素,求出改像素在目标图像的对应位置。
前向变换虽然符合逻辑,但是却使得目标图像上很多位置没有对应的像素。因此一种更合理的方式是使用后向变换,即从目标图像出发,遍历目标图像的每个位置,求出每个位置在原图中的对应像素。此时,公式变为:
![\\begin\{bmatrix\} x^\{s\}\\\\ y^\{s\}\\\\ 1 \\end\{bmatrix\} = \\begin\{bmatrix\} \{x\} & \{y\} & 1 \\end\{bmatrix\} \*\{ \\begin\{bmatrix\} a & b &0 \\\\ c & d & 0\\\\ e &f & 1 \\end\{bmatrix\}\}^\{-1\}][begin_bmatrix_ x_s_ y_s_ 1 _end_bmatrix_ _ _begin_bmatrix_ _x_ _ _y_ _ 1 _end_bmatrix_ _ _begin_bmatrix_ a _ b _0 _ c _ d _ 0_ e _f _ 1 _end_bmatrix_-1]
二、pytorch中的仿射变换
pytorch中就使用的为后向变换。主要涉及两个函数
- F.affine_grid(theta,size)
- F.grid_sample(input, grid, mode=’bilinear’, padding_mode=’zeros’)
1.F.affine_grid根据输入的变换矩阵theta和尺寸利用后向变换求出目标图像每个像素在原图像的位置。
theta是一个\[N,2,3\]的tensor,N为batchsize大小;2行3列共六个参数,为affine的变换矩阵,第一行为x坐标,即横坐标的变换参数,前两个为权重,最后一个为偏移,值得注意的是偏移值是一个相对于图像宽归一化的参数a,c,e(并非像素值),例如0.5表示左移半个图像的宽度。第二行表示y坐标的变换参数(b,d,f)。
size是一个tuple,为(N,C,H,W)
output为[N,h,w,2]的Tensor,表示在原图中的对应位置。
- F.grid_sample()为重采样函数,根据输入的原图和位置对应关系矩阵(F.affine_grid的输出)对原图像素进行重采样,构成变换后的图像。由于重采样过程中,在原图中的位置会出现小数,因此需要对原图进行插值,插值方式为可选参数,默认双线性插值。
下面我们来看一个例子:
将图像顺时针旋转45度,注意pytorch使用的为后向变换。
对于前向变换来说,顺时针旋转45度的变换矩阵为,后向变换应该对其求逆。但是我们可以换一个角度理解,原图到目标图需要顺时针旋转45度,那么目标图到原图不就是逆时针旋转45度吗,因此直接取
带入原公式计算即可
代码如下:
import torch
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt
theta = torch.Tensor([[0.707,0.707,0],[-0.707,0.707,0]]).unsqueeze(dim=0)
img = cv2.imread('achor.png',cv2.IMREAD_GRAYSCALE)
plt.subplot(2,1,1)
plt.imshow(img,cmap='gray')
plt.axis('off')
img = torch.Tensor(img).unsqueeze(0).unsqueeze(0)
grid = F.affine_grid(theta,size=img.shape)
output = F.grid_sample(img,grid)[0].numpy().transpose(1,2,0).squeeze()
plt.subplot(2,1,2)
plt.imshow(output,cmap='gray')
plt.axis('off')
plt.show()
结果如下(pytorch中以图像中心点为原点,与一般的左上角为原点不太一样):
还没有评论,来说两句吧...