NumPy数组的变形(改变数组形状)

  • 内容
  • 评论
  • 相关

在机器学习以及深度学习的任务中,通常需要将处理好的数据以模型能接收的格式输入给模型,然后由模型通过一系列的运算,最终返回一个处理结果。然而,由于不同模型所接收的输入格式不一样,往往需要先对其进行一系列的变形和运算,从而将数据处理成符合模型要求的格式。

在矩阵或者数组的运算中,经常会遇到需要把多个向量或矩阵按某轴方向合并,或展平(如在卷积或循环神经网络中,在全连接层之前,需要把矩阵展平)的情况。下面介绍几种常用的数组变形方法。

修改指定数组的形状是 NumPy 中最常见的操作之一,常见的方法有很多,下表列出了一些常用函数和属性。

表1:Numpy 中改变向量形状的一些函数和属性
函数/属性 描述
arr.reshape() 重新将向量 arr 维度进行改变,不修改向量本身
arr.resize() 重新将向量 arr 维度进行改变,修改向量本身
arr.T 对向量 arr 进行转置
arr.ravel() 对向量 arr 进行展平,即将多维数组变成1维数组,不会产生原数组的副本
arr.flatten() 对向量 arr 进行展平,即将多维数组变成1维数组,返回原数组的副本
arr.squeeze() 只能对维数为1的维度降维。对多维数组使用时不会报错,但是不会产生任何影响
arr.transpose() 对高维矩阵进行轴对换

下面来看一些示例。

reshape() 函数

reshape() 函数用来改变向量的维度(不修改向量本身),请看下面的代码:

import numpy as np

arr =np.arange(10)
print(arr)
# 将向量 arr 维度变换为2行5列
print(arr.reshape(2, 5))
# 指定维度时可以只指定行数或列数, 其他用 -1 代替
print(arr.reshape(5, -1))
print(arr.reshape(-1, 5))

输出结果:

[0 1 2 3 4 5 6 7 8 9]
[[0 1 2 3 4]
  [5 6 7 8 9]]
[[0 1]
  [2 3]
  [4 5]
  [6 7]
  [8 9]]
[[0 1 2 3 4]
  [5 6 7 8 9]]

值得注意的是,reshape() 函数不支持指定行数或列数,所以 -1 在这里是必要的。且所指定的行数或列数一定要能被整除,例如上面代码如果修改为 arr.reshape(3,-1) 即为错误的。

resize() 函数

resize() 函数用来改变向量的维度(修改向量本身),请看下面的代码:

import numpy as np

arr =np.arange(10)
print(arr)
# 将向量 arr 维度变换为2行5列
arr.resize(2, 5)
print(arr)

输出结果:

[0 1 2 3 4 5 6 7 8 9]
[[0 1 2 3 4]
  [5 6 7 8 9]]

T 属性

T 属性用来对向量进行转置,请看下面的的代码:

import numpy as np

arr =np.arange(12).reshape(3,4)
# 向量 arr 为3行4列
print(arr)
# 将向量 arr 进行转置为4行3列
print(arr.T)

输出结果:

[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]
[[ 0  4  8]
  [ 1  5  9]
  [ 2  6 10]
  [ 3  7 11]]

ravel() 函数

ravel() 函数用于向量展平,请看下面的代码:

import numpy as np

arr =np.arange(6).reshape(2, -1)
print(arr)
# 按照列优先, 展平
print("按照列优先, 展平")
print(arr.ravel('F'))
# 按照行优先, 展平
print("按照行优先, 展平")
print(arr.ravel())

输出结果:

[[0 1 2]
  [3 4 5]]
按照列优先,展平
[0 3 1 4 2 5]
按照行优先,展平
[0 1 2 3 4 5]

flatten() 函数

flatten() 函数用来把矩阵转换为向量,这种需求经常出现在卷积网络与全连接层之间。

请看下面的代码:

import numpy as np
a =np.floor(10*np.random.random((3,4)))
print(a)
print(a.flatten())

输出结果:

[[4. 0. 8. 5.]
  [1. 0. 4. 8.]
  [8. 2. 3. 7.]]
[4. 0. 8. 5. 1. 0. 4. 8. 8. 2. 3. 7.]

squeeze() 函数

这是一个主要用来降维的函数,把矩阵中含1的维度去掉,请看下面的代码:

import numpy as np

arr =np.arange(3).reshape(3, 1)
print(arr.shape)  #(3,1)
print(arr.squeeze().shape)  #(3,)
arr1 =np.arange(6).reshape(3,1,2,1)
print(arr1.shape) #(3, 1, 2, 1)
print(arr1.squeeze().shape) #(3, 2)

transpose() 函数

对高维矩阵进行轴对换,这个在深度学习中经常使用,比如把图片中表示颜色顺序的 RGB 改为 GBR。

请看下面的代码:

import numpy as np

arr2 = np.arange(24).reshape(2,3,4)
print(arr2.shape)  #(2, 3, 4)
print(arr2.transpose(1,2,0).shape)  #(3, 4, 2)

本文标题:NumPy数组的变形(改变数组形状)

本文地址:https://www.hosteonscn.com/7801.html

评论

0条评论

发表评论

邮箱地址不会被公开。 必填项已用*标注