xarray 数据处理教程:从入门到精通
一、简介
xarray
是 Python 中用于处理多维数组数据的库,特别适用于带有标签(坐标)的科学数据(如气象、海洋、遥感等)。它基于 NumPy 和 Pandas,支持高效的数据操作、分析和可视化。
核心优势
- 标签化操作:通过维度名和坐标直接访问数据,无需记忆索引位置。
- 多维支持:天然支持多维数组(如时间、纬度、经度)。
- 集成工具:内置 NetCDF、HDF5 等格式读写,支持 Dask 处理大文件。
- 可视化:与 Matplotlib 深度集成,简化数据绘图流程。
二、安装与导入
1. 安装
pip install xarray netCDF4 dask rioxarray
或使用 Conda:
conda install -c conda-forge xarray netCDF4 dask rioxarray
2. 导入库
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
三、数据结构
(一)DataArray
- 定义:带坐标的 N 维数组,类似带标签的 NumPy 数组。
- 示例代码
import xarray as xr
import numpy as np
data = np.random.rand(12, 5, 100, 200)
coords = {
'time': np.arange(12),
'sample': np.arange(5),
'lat': np.linspace(-90, 90, 100),
'lon': np.linspace(-180, 180, 200)
}
da = xr.DataArray(data, dims=['time', 'sample', 'lat', 'lon'], coords=coords)
- 输出结果
<xarray.DataArray (time: 12, sample: 5, lat: 100, lon: 200)>
array([[[[...]], # 12个时间步 × 5个样本 × 100纬度 × 200经度的随机值
...,
[[...]]],
...,
[[...]]])
Coordinates:
* time (time) int64 0 1 2 ... 11
* sample (sample) int64 0 1 2 3 4
* lat (lat) float64 -90.0 -89.1 -88.2 ... 88.2 89.1 90.0
* lon (lon) float64 -180.0 -179.1 -178.2 ... 178.2 179.1 180.0
- 表1:数据结构DataArray
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
xarray.DataArray |
创建带有维度和坐标的多维数组(如海温数据) | data : 数组数据;dims : 维度名列表;coords : 坐标字典 |
生成的 xarray.DataArray 对象 |
(二) Dataset
定义:类似字典的容器,包含多个
DataArray
(变量),共享坐标。表:xarray.Dataset 操作总结
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
xr.Dataset |
创建多变量数据集 | 变量字典、坐标字典 | xarray.Dataset 对象 |
.sel() / .isel() |
按标签或索引选择数据 | 维度名和值 | 子数据集或子数组 |
.mean() / .std() |
计算维度统计量 | 需求平均的维度名 | 统计后的数据集或数据数组 |
.to_netcdf() |
保存为 NetCDF 文件 | 文件路径和写入模式 | 无返回值(保存文件) |
xr.open_dataset() |
读取 NetCDF 文件 | 文件路径和读取引擎 | xarray.Dataset 对象 |
.merge() |
合并两个数据集 | 另一个 Dataset 对象 |
合并后的 xarray.Dataset |
.apply() |
应用自定义函数 | 自定义函数和输入维度 | 应用后的数据集 |
.chunk() |
设置数据分块 | 分块大小字典 | 分块后的 xarray.Dataset |
- 创建示例:
import xarray as xr
import numpy as np
import pandas as pd
ds = xr.Dataset(
{
"temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
"humidity": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
},
coords={
"time": pd.date_range("2025-01-01", periods=3),
"lat": np.linspace(-90, 90, 10),
"lon": np.linspace(-180, 180, 20),
}
)
- 输出结果:
<xarray.Dataset>
Dimensions: (time: 3, lat: 10, lon: 20)
Coordinates:
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
* lat (lat) float64 -90.0 -81.0 ... 81.0 90.0
* lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (time, lat, lon) float64 0.1234 0.5678 ...
humidity (time, lat, lon) float64 0.9876 0.4321 ...
(三)关键说明
- Dataset vs DataArray:
xarray.Dataset
适合处理多变量数据(如温度、湿度、降水)。xarray.DataArray
适合单一变量的多维数组操作。
- 工作流示例:
# 读取数据并选择子集 ds = xr.open_dataset("data.nc") subset = ds.sel(lat=slice(-90, -60), lon=slice(-180, -120)) # 计算统计量并保存 mean_ds = subset.mean(dim="time") mean_ds.to_netcdf("mean_data.nc")
四、数据操作
(一)索引与切片
表:索引与切片
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
.isel() / .sel() |
快速提取特定维度数据 | isel : 按索引提取;sel : 按坐标标签提取 |
提取后的子数组(xarray.DataArray ) |
1. 基于标签选择(.sel()
)
场景:从数据集中提取特定时间步和纬度范围的数据。
subset = ds.sel(time="2025-01-01", lat=-90)
输出结果:
<xarray.Dataset>
Dimensions: (lat: 1, lon: 20)
Coordinates:
time datetime64[ns] 2025-01-01
lat float64 -90.0
lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (lon) float64 0.1234 0.5678 ...
humidity (lon) float64 0.9876 0.4321 ...
2. 基于位置选择(.isel()
)
场景:从数据集中提取第一个时间步和前三个纬度的数据。
# 按索引位置选择数据
subset_isel = ds.isel(time=0, lat=slice(0, 3))
输出结果:
<xarray.Dataset>
Dimensions: (time: 1, lat: 3, lon: 20)
Coordinates:
time datetime64[ns] 2025-01-01
* lat (lat) float64 -90.0 -76.36 -62.73
* lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (time, lat, lon) float64 0.1234 0.5678 ...
3. .sel()
和 .isel()
联合使用
场景:结合标签和索引选择数据(例如,选择特定时间步和固定经度索引)。
# 按时间标签和经度索引选择数据
subset_mixed = ds.sel(time="2025-01-01").isel(lon=0)
输出结果:
<xarray.Dataset>
Dimensions: (time: 1, lat: 10, lon: 1)
Coordinates:
time datetime64[ns] 2025-01-01
* lat (lat) float64 -90.0 -81.0 ... 81.0 90.0
lon float64 -180.0
Data variables:
temperature (time, lat, lon) float64 0.1234 0.5678 ...
4. 多维选择与切片
场景:同时选择多个维度(时间、纬度、经度)并使用切片操作。
# 按时间标签、纬度范围和经度切片选择数据
subset_slice = ds.sel(
time="2025-01-01",
lat=slice(-90, -60),
lon=slice(-180, -120)
)
输出结果:
<xarray.Dataset>
Dimensions: (time: 1, lat: 3, lon: 7)
Coordinates:
time datetime64[ns] 2025-01-01
* lat (lat) float64 -90.0 -76.36 -62.73
* lon (lon) float64 -180.0 -163.6 ... -126.3 -120.0
Data variables:
temperature (time, lat, lon) float64 0.1234 0.5678 ...
5. 关键说明
.sel()
vs.isel()
:.sel()
:使用坐标标签(如time="2025-01-01"
、lat=-90
)进行选择,适合已知具体坐标的场景。.isel()
:使用索引位置(如time=0
、lat=slice(0, 3)
)进行选择,适合已知数组索引的场景。
切片操作:
- 可以通过
slice(start, end)
实现对维度的范围选择(例如lat=slice(-90, -60)
)。 - 切片是左闭右开的,即包含
start
,不包含end
。
- 可以通过
多维联合选择:
- 可以联合使用
.sel()
和.isel()
,例如先按标签选择时间,再按索引选择经度。 - 也可以通过链式调用实现多步选择(如
ds.sel(...).isel(...)
)。
- 可以联合使用
(二) 数据计算
1. 聚合运算
(1) 计算单个维度的平均值
场景:从数据集中计算时间维度(time
)的平均值。
import xarray as xr
import numpy as np
import pandas as pd
# 创建示例数据集
ds = xr.Dataset(
{
"temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
"humidity": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
},
coords={
"time": pd.date_range("2025-01-01", periods=3),
"lat": np.linspace(-90, 90, 10),
"lon": np.linspace(-180, 180, 20),
}
)
# 计算时间维度的平均值
mean_time = ds.mean(dim="time")
输出结果:
<xarray.Dataset>
Dimensions: (lat: 10, lon: 20)
Coordinates:
* lat (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
* lon (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
temperature (lat, lon) float64 0.4567 0.8901 ...
humidity (lat, lon) float64 0.7654 0.3210 ...
(2) 计算多个维度的平均值
场景:从数据集中计算时间和纬度维度的平均值。
# 计算时间和纬度维度的平均值
mean_time_lat = ds.mean(dim=["time", "lat"])
输出结果:
<xarray.Dataset>
Dimensions: (lon: 20)
Coordinates:
* lon (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
temperature (lon) float64 0.6789 ...
humidity (lon) float64 0.5432 ...
(3) 计算单个维度的标准差
场景:从数据集中计算纬度维度(lat
)的标准差。
# 计算纬度维度的标准差
std_lat = ds.std(dim="lat")
输出结果:
<xarray.Dataset>
Dimensions: (time: 3, lon: 20)
Coordinates:
* time (time) datetime64[ns] 2025-01-01 ... 2025-01-03
* lon (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
temperature (time, lon) float64 0.2345 0.6789 ...
humidity (time, lon) float64 0.3456 0.7890 ...
(4) 计算多个维度的标准差
场景:从数据集中计算时间和经度维度的标准差。
# 计算时间和经度维度的标准差
std_time_lon = ds.std(dim=["time", "lon"])
输出结果:
<xarray.Dataset>
Dimensions: (lat: 10)
Coordinates:
* lat (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
Data variables:
temperature (lat) float64 0.1234 ...
humidity (lat) float64 0.4567 ...
(5) 忽略缺失值计算统计量
场景:数据集中包含缺失值(NaN
),需要在计算时跳过缺失值。
# 创建包含缺失值的数据集
ds_nan = xr.Dataset(
{
"temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
},
coords={
"time": pd.date_range("2025-01-01", periods=3),
"lat": np.linspace(-90, 90, 10),
"lon": np.linspace(-180, 180, 20),
}
)
# 随机插入缺失值
ds_nan.temperature.values[0, 0, 0] = np.nan
# 计算时间维度的平均值(跳过缺失值)
mean_time_skipna = ds_nan.mean(dim="time", skipna=True)
输出结果:
<xarray.Dataset>
Dimensions: (lat: 10, lon: 20)
Coordinates:
* lat (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
* lon (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
temperature (lat, lon) float64 0.4567 0.8901 ...
(6) 关键说明
.mean()
和.std()
的区别:.mean()
:计算指定维度的平均值。.std()
:计算指定维度的标准差,默认为样本标准差(ddof=1
)。
维度选择:
- 可以指定单个维度(如
dim="time"
)或多个维度(如dim=["time", "lat"]
)。 - 维度减少后,输出数据集的维度会相应调整(如从
(time, lat, lon)
变为(lat, lon)
)。
- 可以指定单个维度(如
缺失值处理:
- 通过
skipna=True
可以跳过缺失值(NaN
)进行计算,避免因缺失值导致整个统计结果为NaN
。
- 通过
2. 算术运算
# 温度乘以 2,降水加 10
new_ds = ds * 2 + 10
3. 数据重塑
总结表格
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
.stack() |
将多个维度堆叠成一个新维度 | new_dim_name : 新维度名;dim : 原始维度列表 |
维度被堆叠后的 xarray.Dataset |
.transpose() |
调整维度顺序(不改变维度数量) | *dims : 新维度顺序 |
维度顺序调整后的 xarray.Dataset |
(1) .stack()
:将多个维度堆叠成一个新维度
场景:将 lat
和 lon
维度堆叠成一个名为 space
的新维度。
示例代码
import xarray as xr
import numpy as np
import pandas as pd
# 创建示例数据集
ds = xr.Dataset(
{
"temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
},
coords={
"time": pd.date_range("2025-01-01", periods=3),
"lat": np.linspace(-90, 90, 10),
"lon": np.linspace(-180, 180, 20),
}
)
# 将 lat 和 lon 堆叠成 space 维度
stacked_ds = ds.stack(space=["lat", "lon"])
输出结果:
<xarray.Dataset>
Dimensions: (time: 3, space: 200)
Coordinates:
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
space (space) MultiIndex
- lat (space) float64 -90.0 -90.0 ... 90.0 90.0
- lon (space) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (time, space) float64 0.1234 0.5678 ...
(2) .transpose()
:调整维度顺序
场景:将 time
, lat
, lon
的维度顺序调整为 lon
, lat
, time
。
示例代码:
# 调整维度顺序
transposed_ds = ds.transpose("lon", "lat", "time")
输出结果:
<xarray.Dataset>
Dimensions: (lon: 20, lat: 10, time: 3)
Coordinates:
* lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
* lat (lat) float64 -90.0 -81.0 ... 81.0 90.0
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
Data variables:
temperature (lon, lat, time) float64 0.1234 0.5678 ...
(3) .stack()
+ .transpose()
联合使用
场景:先将 lat
和 lon
堆叠成 space
,再调整维度顺序为 space
, time
。
示例代码:
# 堆叠后调整维度顺序
stacked_transposed_ds = ds.stack(space=["lat", "lon"]).transpose("space", "time")
输出结果:
<xarray.Dataset>
Dimensions: (space: 200, time: 3)
Coordinates:
space (space) MultiIndex
- lat (space) float64 -90.0 -90.0 ... 90.0 90.0
- lon (space) float64 -180.0 -171.0 ... 171.0 180.0
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
Data variables:
temperature (space, time) float64 0.1234 0.5678 ...
(4) 关键说明
.stack()
vs.transpose()
:.stack()
:减少维度数量,将多个维度合并为一个新维度(如lat
和lon
→space
)。.transpose()
:不改变维度数量,仅调整维度的排列顺序(如time, lat, lon
→lon, lat, time
)。
应用场景:
.stack()
:- 将多维数据转换为二维,便于进行某些计算(如机器学习模型输入)。
- 简化高维数据的可视化(如将
lat
和lon
合并为space
后绘图)。
.transpose()
:- 调整数据维度顺序以匹配其他数据集或模型的输入格式。
- 提高代码可读性,使维度顺序更符合逻辑(如先经度后纬度)。
注意事项:
.stack()
会生成MultiIndex
,可通过.unstack()
恢复原始维度。.transpose()
不会修改原始数据,而是返回一个新对象(惰性操作)。
4. 数据聚合
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
.groupby() / .mean() |
按维度分组计算均值 | group : 分组维度(如 time.month );dim : 聚合维度 |
聚合后的 xarray.DataArray |
.resample() |
时间序列重采样(如日→月) | freq : 重采样频率(如 MS 表示月初);dim : 时间维度名 |
重采样后的 xarray.DataArray |
示例代码
# 按月份分组计算均值
monthly_mean = da.groupby("time.month").mean(dim="time")
# 时间序列重采样(日→月)
monthly_resample = da.resample(time="MS").mean()
5. 合并两个数据集
ds2 = xr.Dataset({"precipitation": (["time", "lat", "lon"], np.random.rand(3, 10, 20))})
ds_merged = ds.merge(ds2)
输出结果:
<xarray.Dataset>
Dimensions: (time: 3, lat: 10, lon: 20)
Coordinates:
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
* lat (lat) float64 -90.0 -81.0 ... 81.0 90.0
* lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (time, lat, lon) float64 0.1234 0.5678 ...
humidity (time, lat, lon) float64 0.9876 0.4321 ...
precipitation (time, lat, lon) float64 0.3456 0.7890 ...
6. 应用自定义函数
def custom_func(arr):
return arr.max() - arr.min()
ds_custom = ds.apply(custom_func)
输出结果:
<xarray.Dataset>
Dimensions: (time: 3, lat: 10, lon: 20)
Coordinates:
* time (time) datetime64[ns] 2025-01-01 2025-01-02 ...
* lat (lat) float64 -90.0 -81.0 ... 81.0 90.0
* lon (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
temperature (time, lat, lon) float64 0.4321 0.8765 ...
humidity (time, lat, lon) float64 0.5432 0.9876 ...
五、数据可视化
表:可视化方法
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
.plot() |
快速可视化(等值线图、色阶图) | x , y : 维度名;cbar_kwargs : 颜色条参数;transform : 投影转换 |
matplotlib.axes.Axes 对象 |
.plot.scatter() |
散点图可视化 | x , y : 维度名;c : 颜色变量;size : 点大小 |
matplotlib.axes.Axes 对象 |
- 示例代码
# 绘制等值线图
da.plot.contourf(x="lon", y="lat", cmap="viridis")
# 绘制散点图
da.plot.scatter(x="lon", y="lat", c="temperature", size="precipitation")
(一)二维分布图
# 绘制温度的空间分布
ds["temp"].isel(time=0).plot(cmap="viridis")
plt.title("Temperature Distribution")
plt.show()
(二)时间序列图
# 绘制单个网格点的时间序列
ds["temp"].sel(lat=40, lon=100).plot.line(x="time")
plt.title("Temperature Time Series")
plt.show()
六、高级功能
(一) 缺失值处理
1. 填充缺失值
# 用 0 填充缺失值
filled_temp = ds["temp"].fillna(0)
2. 插值
# 使用线性插值填充缺失值
interpolated = ds["temp"].interpolate_na(dim="lat", method="linear")
(二) 时间重采样
# 将日数据重采样为月均值
monthly_mean = ds.resample(time="1M").mean()
(三)地理信息处理
1. 设置坐标系
import rioxarray
ds.rio.write_crs("EPSG:4326", inplace=True) # 设置为 WGS84 坐标系
2. 绘制地理投影图
import cartopy.crs as ccrs
ax = plt.axes(projection=ccrs.PlateCarree())
ds["temp"].isel(time=0).plot(ax=ax, transform=ccrs.PlateCarree())
ax.coastlines()
plt.show()
七、数据输入与输出
表:NetCDF 读取与保存
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
xr.open_dataset |
读取单个 NetCDF 文件 | filename : 文件路径;engine : 读取引擎(如 netcdf4 ) |
xarray.Dataset 对象 |
xr.to_netcdf |
保存数据为 NetCDF 文件 | filename : 保存路径;mode : 写入模式(如 w 表示覆盖) |
无返回值(直接写入文件) |
xr.open_mfdataset |
批量读取多文件数据集 | paths : 文件路径列表;engine : 读取引擎(如 h5netcdf );parallel : 是否并行读取;preprocess : 预处理函数 |
合并后的 xarray.Dataset |
xr.save_mfdataset |
批量保存数据集到文件 | datasets : 数据集列表;paths : 保存路径列表;encoding : 变量编码参数(如压缩设置) |
无返回值(直接写入文件) |
(一)读取 NetCDF 文件
ds = xr.open_dataset("data.nc") # 读取单个文件
ds = xr.open_mfdataset("data/*.nc", combine="by_coords") # 合并多个文件
示例代码
# 批量读取 NetCDF 文件
import xarray as xr
ds = xr.open_mfdataset("data/*.nc", engine="h5netcdf", parallel=True)
# 批量保存数据集
xr.save_mfdataset([ds1, ds2], ["output1.nc", "output2.nc"], encoding={var: {"zlib": True}})
(二)保存数据
# 保存为 NetCDF 文件并启用压缩
ds.to_netcdf("output.nc", encoding={"temp": {"zlib": True}})
八、性能优化
表:数据分块与性能优化
操作/方法 | 功能 | 输入参数 | 输出参数 |
---|---|---|---|
.chunk() |
设置数据分块大小 | chunks : 分块字典(如 {"time": 10, "lat": 100} ) |
带分块的 xarray.DataArray |
.compute() |
触发延迟计算 | 无 | 实际计算结果(xarray.DataArray 或 xarray.Dataset ) |
(一)分块处理(Dask)
import dask.array as da
ds = ds.chunk({"time": 10}) # 将时间维度分块
(二) 内存优化
- 使用
.persist()
或.compute()
控制计算时机。 - 避免不必要的中间变量。
示例代码
# 设置分块
da_chunked = da.chunk({"time": 10, "lat": 100})
# 触发计算
result = da_chunked.mean().compute()
九、总结
(一)核心流程
- 读取数据 → 2. 访问变量 → 3. 选择/切片 → 4. 计算/分析 → 5. 保存/可视化
(二)关键优势
- 标签化操作:通过维度名和坐标直接访问数据。
- 高效处理:支持多维数据、地理信息和大文件分块。
- 易用性:与 Pandas、Matplotlib 无缝集成。
(三)注意事项
- 使用
.sel()
和.isel()
时注意维度名称和索引范围。 - 大数据集需结合 Dask 分块处理(
.chunk()
)。 - 保存时启用压缩(
zlib=True
)可减少文件体积。