一文详解affine_grid 与 grid_sample以及与opencv坐标系的关系

发布于:2024-04-26 ⋅ 阅读:(23) ⋅ 点赞:(0)

前言

网上资料乱七八糟,本文通过坐标系和变换的角度,系统梳理两个操作的作用

基本仿射变换

二维仿射变换,我们可以综合为一个2x2的旋转矩阵R和一个2x1的平移矩阵t,[R,t]组合起来就是2x3的矩阵
我们可以增广为3x3的矩阵,只需最后一行加上0 0 1即可。

更细致描述请参阅games101

在这里插入图片描述

opencv 坐标系

-------------------->x
|
|
|
|
V y
opencv的坐标系是从左上角开始(0,0) 其中x坐标向右侧延申,y坐标向下延申的。
因此当我们按照想旋转和平移时,需要按照当前坐标系来理解,例如原先的正向现在为顺时针,向y轴平移现在就是向下平移。
我们如下代码给出了先旋转45°然后向x轴方向平移50个像素的坐标

import torch
import cv2
from skimage.transform._geometric import _umeyama as get_sym_mat
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
I = cv2.imread('test.png')
print(I.size)
rotate_theta = np.radians(45) # 转换为弧度
rotate_matrix = np.array([
        [np.cos(rotate_theta), -np.sin(rotate_theta), 50],
        [np.sin(rotate_theta), np.cos(rotate_theta), 0],
        [0, 0, 1]
    ])
print(rotate_matrix)
M=rotate_matrix[:2, :]
J = cv2.warpAffine(I, M,I.shape[:2])
cv2.imshow('show', J)
cv2.waitKey(0)

在这里插入图片描述
我们发现事实确实如我们所料

pytorch坐标系

pytorch的坐标是在-1到1的归一化坐标系下的值,例如左上角是(-1,-1),右上角是(1,1)
图片的中心在图片正中间

将已知的变换矩阵转换为pytorch的变换矩阵

STN网络中最常用的两个函数是
affine_gridgrid_sample,这两个函数组合使用可以完成类似的warp操作,而且支持批量操作。
我们通常的使用方法是
给定一个theta矩阵,通过theta矩阵变换一个网格grid,dist的每个点根据grid在src里插值填满整个图。

grid = F.affine_grid(theta, [1, C, H, W])
dist = F.grid_sample(src, grid)

一个非常自然的想法就是,我们希望也能像opencv一样能利用我们学过的知识,手动控制变化。
当我们获得了opencv坐标系下的一个仿射变换矩阵如下所示时,我们需要进行坐标转换
[ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] \left[ \begin{array}{ccc} a_1 & a_2 & t_x \\ a_3 & a_4 & t_y \\ 0 & 0 & 1 \end{array} \right] a1a30a2a40txty1

如果你学过线性代数或者矩阵论的相关知识,就知道我们可以用相似变换来解决这种坐标系下的变换问题。
不了解的可以查看线性代数的本质-基变换
对于任意一个原图pytorch坐标系下的点 ( u 1 , v 1 ) (u_1,v_1) (u1,v1)经过变换后,即可找到新图中 ( u 2 , v 2 ) (u_2,v_2) (u2,v2)的点
[ u 2 v 2 1 ] = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 [ u 1 v 1 1 ] \begin{bmatrix}u_2\\v_2\\1\end{bmatrix}=\begin{bmatrix}\frac{2}{W}&0&-1\\0&\frac{2}{H}&-1\\0&0&1\end{bmatrix}\begin{bmatrix}a_1&a_2&t_x\\a_3&a_4&t_y\\0&0&1\end{bmatrix}\begin{bmatrix}\frac{2}{W}&0&-1\\0&\frac{2}{H}&-1\\0&0&1\end{bmatrix}^{-1}\begin{bmatrix}u_1\\v_1\\1\end{bmatrix} u2v21 = W2000H20111 a1a30a2a40txty1 W2000H20111 1 u1v11

但是由于实际上theta控制着dst图去找src图,因此我们实际上的theta是上述变换的逆矩阵
θ = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ a 1 a 2 t x a 3 a 4 t y 0 0 1 ] − 1 [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 \theta = \left[ \begin{array}{ccc} \frac{2}{W} & 0 & -1 \\ 0 & \frac{2}{H} & -1 \\ 0 & 0 & 1 \end{array} \right] \left[ \begin{array}{ccc} a_1 & a_2 & t_x \\ a_3 & a_4 & t_y \\ 0 & 0 & 1 \end{array} \right]^{-1} \left[ \begin{array}{ccc} \frac{2}{W} & 0 & -1 \\ 0 & \frac{2}{H} & -1 \\ 0 & 0 & 1 \end{array} \right]^{-1} θ= W2000H20111 a1a30a2a40txty1 1 W2000H20111 1

实战

给定一张宽高为W,H的图像,以及其中的某一块区域(x,y,w,h)意思是在原图x,y像素位置,有像素长度为w和h的框,请你将这个框截取出来并通过仿射变换为一个W和H大小的图片
我们依次按照theta矩阵填入

θ = [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] [ W / w 0 − x W / w 0 H / h − y H / h 0 0 1 ] − 1 [ 2 W 0 − 1 0 2 H − 1 0 0 1 ] − 1 \theta = \left[ \begin{array}{ccc} \frac{2}{W} & 0 & -1 \\ 0 & \frac{2}{H} & -1 \\ 0 & 0 & 1 \end{array} \right] \left[ \begin{array}{ccc} W/w & 0 & -xW/w \\ 0 & H/h & -yH/h \\ 0 & 0 & 1 \end{array} \right]^{-1} \left[ \begin{array}{ccc} \frac{2}{W} & 0 & -1 \\ 0 & \frac{2}{H} & -1 \\ 0 & 0 & 1 \end{array} \right]^{-1} θ= W2000H20111 W/w000H/h0xW/wyH/h1 1 W2000H20111 1
中间变换矩阵的含义是先将小图 x,y 的源点位置移动到0,0,随后进行尺度缩放。
θ = [ w / W 0 − 1 + 2 ∗ ( x + w / 2 ) / W 0 h / H − 1 + 2 ∗ ( y + h / 2 ) / H 0 0 1 ] \theta = \left[ \begin{array}{ccc} w/W & 0 & -1 + 2 * (x +w/2)/ W \\ 0 & h/H & -1 + 2 * (y +h/2)/ H \\ 0 & 0 & 1 \end{array} \right] θ= w/W000h/H01+2(x+w/2)/W1+2(y+h/2)/H1

import torch
import cv2
from skimage.transform._geometric import _umeyama as get_sym_mat
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
I = cv2.imread('test.png')
H, W, C = I.shape

It = torch.from_numpy(I).type(torch.float32).permute(2, 0, 1).unsqueeze(0)
x, y, w, h = [80,70,100,100]
scale_x = w / W
scale_y = h / H
translate_x = (-1 + 2 * (x +w/2)/ W) 
translate_y =(-1 + 2 * (y +h/2)/ H) 
# 构造 theta 矩阵
theta = torch.tensor([[scale_x, 0, translate_x],
                    [0, scale_y, translate_y]], dtype=torch.float).unsqueeze(0)

grid = F.affine_grid(theta, [1, C, H, W])
Jt = F.grid_sample(It, grid)
J = Jt.squeeze().permute(1, 2, 0).detach().numpy().astype('uint8')
cv2.imshow('show', J)
cv2.waitKey(0)

在这里插入图片描述

参考资料

https://www.zhihu.com/question/294673086


网站公告

今日签到

点亮在社区的每一天
去签到