【数据处理】xarray 数据处理教程:从入门到精通

发布于:2025-05-15 ⋅ 阅读:(13) ⋅ 点赞:(0)

xarray 数据处理教程:从入门到精通

一、简介

xarray 是 Python 中用于处理多维数组数据的库,特别适用于带有标签(坐标)的科学数据(如气象、海洋、遥感等)。它基于 NumPy 和 Pandas,支持高效的数据操作、分析和可视化。

核心优势

  • 标签化操作:通过维度名和坐标直接访问数据,无需记忆索引位置。
  • 多维支持:天然支持多维数组(如时间、纬度、经度)。
  • 集成工具:内置 NetCDF、HDF5 等格式读写,支持 Dask 处理大文件。
  • 可视化:与 Matplotlib 深度集成,简化数据绘图流程。

二、安装与导入

1. 安装

pip install xarray netCDF4 dask rioxarray

或使用 Conda:

conda install -c conda-forge xarray netCDF4 dask rioxarray

2. 导入库

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

三、数据结构

(一)DataArray

  • 定义:带坐标的 N 维数组,类似带标签的 NumPy 数组。
  • 示例代码
import xarray as xr
import numpy as np

data = np.random.rand(12, 5, 100, 200)
coords = {
    'time': np.arange(12),
    'sample': np.arange(5),
    'lat': np.linspace(-90, 90, 100),
    'lon': np.linspace(-180, 180, 200)
}
da = xr.DataArray(data, dims=['time', 'sample', 'lat', 'lon'], coords=coords)
  • 输出结果
<xarray.DataArray (time: 12, sample: 5, lat: 100, lon: 200)>
array([[[[...]],  # 12个时间步 × 5个样本 × 100纬度 × 200经度的随机值
        ...,
        [[...]]],
       ...,
       [[...]]])
Coordinates:
  * time      (time) int64 0 1 2 ... 11
  * sample    (sample) int64 0 1 2 3 4
  * lat       (lat) float64 -90.0 -89.1 -88.2 ... 88.2 89.1 90.0
  * lon       (lon) float64 -180.0 -179.1 -178.2 ... 178.2 179.1 180.0
  • 表1:数据结构DataArray
操作/方法 功能 输入参数 输出参数
xarray.DataArray 创建带有维度和坐标的多维数组(如海温数据) data: 数组数据;dims: 维度名列表;coords: 坐标字典 生成的 xarray.DataArray 对象

(二) Dataset

  • 定义:类似字典的容器,包含多个 DataArray(变量),共享坐标。

  • 表:xarray.Dataset 操作总结

操作/方法 功能 输入参数 输出参数
xr.Dataset 创建多变量数据集 变量字典、坐标字典 xarray.Dataset 对象
.sel() / .isel() 按标签或索引选择数据 维度名和值 子数据集或子数组
.mean() / .std() 计算维度统计量 需求平均的维度名 统计后的数据集或数据数组
.to_netcdf() 保存为 NetCDF 文件 文件路径和写入模式 无返回值(保存文件)
xr.open_dataset() 读取 NetCDF 文件 文件路径和读取引擎 xarray.Dataset 对象
.merge() 合并两个数据集 另一个 Dataset 对象 合并后的 xarray.Dataset
.apply() 应用自定义函数 自定义函数和输入维度 应用后的数据集
.chunk() 设置数据分块 分块大小字典 分块后的 xarray.Dataset
  • 创建示例
import xarray as xr
import numpy as np
import pandas as pd

ds = xr.Dataset(
    {
        "temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
        "humidity": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
    },
    coords={
        "time": pd.date_range("2025-01-01", periods=3),
        "lat": np.linspace(-90, 90, 10),
        "lon": np.linspace(-180, 180, 20),
    }
)
  • 输出结果
<xarray.Dataset>
Dimensions:     (time: 3, lat: 10, lon: 20)
Coordinates:
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
  * lat         (lat) float64 -90.0 -81.0 ... 81.0 90.0
  * lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (time, lat, lon) float64 0.1234 0.5678 ...
    humidity    (time, lat, lon) float64 0.9876 0.4321 ...

(三)关键说明

  1. Dataset vs DataArray
    • xarray.Dataset 适合处理多变量数据(如温度、湿度、降水)。
    • xarray.DataArray 适合单一变量的多维数组操作。
  2. 工作流示例
    # 读取数据并选择子集
    ds = xr.open_dataset("data.nc")
    subset = ds.sel(lat=slice(-90, -60), lon=slice(-180, -120))
    
    # 计算统计量并保存
    mean_ds = subset.mean(dim="time")
    mean_ds.to_netcdf("mean_data.nc")
    

四、数据操作

(一)索引与切片

表:索引与切片

操作/方法 功能 输入参数 输出参数
.isel() / .sel() 快速提取特定维度数据 isel: 按索引提取;sel: 按坐标标签提取 提取后的子数组(xarray.DataArray
1. 基于标签选择(.sel()

场景:从数据集中提取特定时间步和纬度范围的数据。

subset = ds.sel(time="2025-01-01", lat=-90)

输出结果

<xarray.Dataset>
Dimensions:     (lat: 1, lon: 20)
Coordinates:
    time        datetime64[ns] 2025-01-01
    lat         float64 -90.0
    lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (lon) float64 0.1234 0.5678 ...
    humidity    (lon) float64 0.9876 0.4321 ...
2. 基于位置选择(.isel()

场景:从数据集中提取第一个时间步和前三个纬度的数据。

# 按索引位置选择数据
subset_isel = ds.isel(time=0, lat=slice(0, 3))

输出结果

<xarray.Dataset>
Dimensions:     (time: 1, lat: 3, lon: 20)
Coordinates:
    time        datetime64[ns] 2025-01-01
  * lat         (lat) float64 -90.0 -76.36 -62.73
  * lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (time, lat, lon) float64 0.1234 0.5678 ...
3. .sel().isel() 联合使用

场景:结合标签和索引选择数据(例如,选择特定时间步和固定经度索引)。

# 按时间标签和经度索引选择数据
subset_mixed = ds.sel(time="2025-01-01").isel(lon=0)

输出结果

<xarray.Dataset>
Dimensions:     (time: 1, lat: 10, lon: 1)
Coordinates:
    time        datetime64[ns] 2025-01-01
  * lat         (lat) float64 -90.0 -81.0 ... 81.0 90.0
    lon         float64 -180.0
Data variables:
    temperature (time, lat, lon) float64 0.1234 0.5678 ...

4. 多维选择与切片

场景:同时选择多个维度(时间、纬度、经度)并使用切片操作。

# 按时间标签、纬度范围和经度切片选择数据
subset_slice = ds.sel(
    time="2025-01-01",
    lat=slice(-90, -60),
    lon=slice(-180, -120)
)

输出结果

<xarray.Dataset>
Dimensions:     (time: 1, lat: 3, lon: 7)
Coordinates:
    time        datetime64[ns] 2025-01-01
  * lat         (lat) float64 -90.0 -76.36 -62.73
  * lon         (lon) float64 -180.0 -163.6 ... -126.3 -120.0
Data variables:
    temperature (time, lat, lon) float64 0.1234 0.5678 ...
5. 关键说明
  1. .sel() vs .isel()

    • .sel():使用坐标标签(如 time="2025-01-01"lat=-90)进行选择,适合已知具体坐标的场景。
    • .isel():使用索引位置(如 time=0lat=slice(0, 3))进行选择,适合已知数组索引的场景。
  2. 切片操作

    • 可以通过 slice(start, end) 实现对维度的范围选择(例如 lat=slice(-90, -60))。
    • 切片是左闭右开的,即包含 start,不包含 end
  3. 多维联合选择

    • 可以联合使用 .sel().isel(),例如先按标签选择时间,再按索引选择经度。
    • 也可以通过链式调用实现多步选择(如 ds.sel(...).isel(...))。

(二) 数据计算

1. 聚合运算
(1) 计算单个维度的平均值

场景:从数据集中计算时间维度(time)的平均值。

import xarray as xr
import numpy as np
import pandas as pd

# 创建示例数据集
ds = xr.Dataset(
    {
        "temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
        "humidity": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
    },
    coords={
        "time": pd.date_range("2025-01-01", periods=3),
        "lat": np.linspace(-90, 90, 10),
        "lon": np.linspace(-180, 180, 20),
    }
)

# 计算时间维度的平均值
mean_time = ds.mean(dim="time")

输出结果

<xarray.Dataset>
Dimensions:     (lat: 10, lon: 20)
Coordinates:
  * lat         (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
  * lon         (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
    temperature (lat, lon) float64 0.4567 0.8901 ...
    humidity    (lat, lon) float64 0.7654 0.3210 ...

(2) 计算多个维度的平均值

场景:从数据集中计算时间和纬度维度的平均值。

# 计算时间和纬度维度的平均值
mean_time_lat = ds.mean(dim=["time", "lat"])

输出结果

<xarray.Dataset>
Dimensions:     (lon: 20)
Coordinates:
  * lon         (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
    temperature (lon) float64 0.6789 ...
    humidity    (lon) float64 0.5432 ...
(3) 计算单个维度的标准差

场景:从数据集中计算纬度维度(lat)的标准差。

# 计算纬度维度的标准差
std_lat = ds.std(dim="lat")

输出结果

<xarray.Dataset>
Dimensions:     (time: 3, lon: 20)
Coordinates:
  * time        (time) datetime64[ns] 2025-01-01 ... 2025-01-03
  * lon         (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
    temperature (time, lon) float64 0.2345 0.6789 ...
    humidity    (time, lon) float64 0.3456 0.7890 ...
(4) 计算多个维度的标准差

场景:从数据集中计算时间和经度维度的标准差。

# 计算时间和经度维度的标准差
std_time_lon = ds.std(dim=["time", "lon"])

输出结果

<xarray.Dataset>
Dimensions:     (lat: 10)
Coordinates:
  * lat         (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
Data variables:
    temperature (lat) float64 0.1234 ...
    humidity    (lat) float64 0.4567 ...
(5) 忽略缺失值计算统计量

场景:数据集中包含缺失值(NaN),需要在计算时跳过缺失值。

# 创建包含缺失值的数据集
ds_nan = xr.Dataset(
    {
        "temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
    },
    coords={
        "time": pd.date_range("2025-01-01", periods=3),
        "lat": np.linspace(-90, 90, 10),
        "lon": np.linspace(-180, 180, 20),
    }
)

# 随机插入缺失值
ds_nan.temperature.values[0, 0, 0] = np.nan

# 计算时间维度的平均值(跳过缺失值)
mean_time_skipna = ds_nan.mean(dim="time", skipna=True)

输出结果

<xarray.Dataset>
Dimensions:     (lat: 10, lon: 20)
Coordinates:
  * lat         (lat) float64 -90.0 -81.0 -72.0 ... 72.0 81.0 90.0
  * lon         (lon) float64 -180.0 -171.0 -162.0 ... 162.0 171.0 180.0
Data variables:
    temperature (lat, lon) float64 0.4567 0.8901 ...

(6) 关键说明
  1. .mean().std() 的区别

    • .mean():计算指定维度的平均值
    • .std():计算指定维度的标准差,默认为样本标准差(ddof=1)。
  2. 维度选择

    • 可以指定单个维度(如 dim="time")或多个维度(如 dim=["time", "lat"])。
    • 维度减少后,输出数据集的维度会相应调整(如从 (time, lat, lon) 变为 (lat, lon))。
  3. 缺失值处理

    • 通过 skipna=True 可以跳过缺失值(NaN)进行计算,避免因缺失值导致整个统计结果为 NaN
2. 算术运算
# 温度乘以 2,降水加 10
new_ds = ds * 2 + 10
3. 数据重塑

总结表格

操作/方法 功能 输入参数 输出参数
.stack() 将多个维度堆叠成一个新维度 new_dim_name: 新维度名;dim: 原始维度列表 维度被堆叠后的 xarray.Dataset
.transpose() 调整维度顺序(不改变维度数量) *dims: 新维度顺序 维度顺序调整后的 xarray.Dataset

(1) .stack():将多个维度堆叠成一个新维度

场景:将 latlon 维度堆叠成一个名为 space 的新维度。
示例代码

import xarray as xr
import numpy as np
import pandas as pd

# 创建示例数据集
ds = xr.Dataset(
    {
        "temperature": (["time", "lat", "lon"], np.random.rand(3, 10, 20)),
    },
    coords={
        "time": pd.date_range("2025-01-01", periods=3),
        "lat": np.linspace(-90, 90, 10),
        "lon": np.linspace(-180, 180, 20),
    }
)

# 将 lat 和 lon 堆叠成 space 维度
stacked_ds = ds.stack(space=["lat", "lon"])

输出结果

<xarray.Dataset>
Dimensions:     (time: 3, space: 200)
Coordinates:
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
    space       (space) MultiIndex
      - lat     (space) float64 -90.0 -90.0 ... 90.0 90.0
      - lon     (space) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (time, space) float64 0.1234 0.5678 ...

(2) .transpose():调整维度顺序

场景:将 time, lat, lon 的维度顺序调整为 lon, lat, time
示例代码:

# 调整维度顺序
transposed_ds = ds.transpose("lon", "lat", "time")

输出结果

<xarray.Dataset>
Dimensions:     (lon: 20, lat: 10, time: 3)
Coordinates:
  * lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
  * lat         (lat) float64 -90.0 -81.0 ... 81.0 90.0
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
Data variables:
    temperature (lon, lat, time) float64 0.1234 0.5678 ...
(3) .stack() + .transpose() 联合使用

场景:先将 latlon 堆叠成 space,再调整维度顺序为 space, time
示例代码:

# 堆叠后调整维度顺序
stacked_transposed_ds = ds.stack(space=["lat", "lon"]).transpose("space", "time")

输出结果

<xarray.Dataset>
Dimensions:     (space: 200, time: 3)
Coordinates:
    space       (space) MultiIndex
      - lat     (space) float64 -90.0 -90.0 ... 90.0 90.0
      - lon     (space) float64 -180.0 -171.0 ... 171.0 180.0
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
Data variables:
    temperature (space, time) float64 0.1234 0.5678 ...

(4) 关键说明
  1. .stack() vs .transpose()

    • .stack()减少维度数量,将多个维度合并为一个新维度(如 latlonspace)。
    • .transpose()不改变维度数量,仅调整维度的排列顺序(如 time, lat, lonlon, lat, time)。
  2. 应用场景

    • .stack()
      • 将多维数据转换为二维,便于进行某些计算(如机器学习模型输入)。
      • 简化高维数据的可视化(如将 latlon 合并为 space 后绘图)。
    • .transpose()
      • 调整数据维度顺序以匹配其他数据集或模型的输入格式。
      • 提高代码可读性,使维度顺序更符合逻辑(如先经度后纬度)。
  3. 注意事项

    • .stack() 会生成 MultiIndex,可通过 .unstack() 恢复原始维度。
    • .transpose() 不会修改原始数据,而是返回一个新对象(惰性操作)。
4. 数据聚合
操作/方法 功能 输入参数 输出参数
.groupby() / .mean() 按维度分组计算均值 group: 分组维度(如 time.month);dim: 聚合维度 聚合后的 xarray.DataArray
.resample() 时间序列重采样(如日→月) freq: 重采样频率(如 MS 表示月初);dim: 时间维度名 重采样后的 xarray.DataArray

示例代码

# 按月份分组计算均值
monthly_mean = da.groupby("time.month").mean(dim="time")

# 时间序列重采样(日→月)
monthly_resample = da.resample(time="MS").mean()
5. 合并两个数据集
ds2 = xr.Dataset({"precipitation": (["time", "lat", "lon"], np.random.rand(3, 10, 20))})
ds_merged = ds.merge(ds2)

输出结果

<xarray.Dataset>
Dimensions:     (time: 3, lat: 10, lon: 20)
Coordinates:
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
  * lat         (lat) float64 -90.0 -81.0 ... 81.0 90.0
  * lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (time, lat, lon) float64 0.1234 0.5678 ...
    humidity    (time, lat, lon) float64 0.9876 0.4321 ...
    precipitation (time, lat, lon) float64 0.3456 0.7890 ...
6. 应用自定义函数
def custom_func(arr):
    return arr.max() - arr.min()

ds_custom = ds.apply(custom_func)

输出结果

<xarray.Dataset>
Dimensions:     (time: 3, lat: 10, lon: 20)
Coordinates:
  * time        (time) datetime64[ns] 2025-01-01 2025-01-02 ...
  * lat         (lat) float64 -90.0 -81.0 ... 81.0 90.0
  * lon         (lon) float64 -180.0 -171.0 ... 171.0 180.0
Data variables:
    temperature (time, lat, lon) float64 0.4321 0.8765 ...
    humidity    (time, lat, lon) float64 0.5432 0.9876 ...

五、数据可视化

表:可视化方法

操作/方法 功能 输入参数 输出参数
.plot() 快速可视化(等值线图、色阶图) x, y: 维度名;cbar_kwargs: 颜色条参数;transform: 投影转换 matplotlib.axes.Axes 对象
.plot.scatter() 散点图可视化 x, y: 维度名;c: 颜色变量;size: 点大小 matplotlib.axes.Axes 对象
  • 示例代码
# 绘制等值线图
da.plot.contourf(x="lon", y="lat", cmap="viridis")

# 绘制散点图
da.plot.scatter(x="lon", y="lat", c="temperature", size="precipitation")

(一)二维分布图

# 绘制温度的空间分布
ds["temp"].isel(time=0).plot(cmap="viridis")
plt.title("Temperature Distribution")
plt.show()

(二)时间序列图

# 绘制单个网格点的时间序列
ds["temp"].sel(lat=40, lon=100).plot.line(x="time")
plt.title("Temperature Time Series")
plt.show()

六、高级功能

(一) 缺失值处理

1. 填充缺失值
# 用 0 填充缺失值
filled_temp = ds["temp"].fillna(0)
2. 插值
# 使用线性插值填充缺失值
interpolated = ds["temp"].interpolate_na(dim="lat", method="linear")

(二) 时间重采样

# 将日数据重采样为月均值
monthly_mean = ds.resample(time="1M").mean()

(三)地理信息处理

1. 设置坐标系
import rioxarray
ds.rio.write_crs("EPSG:4326", inplace=True)  # 设置为 WGS84 坐标系
2. 绘制地理投影图
import cartopy.crs as ccrs
ax = plt.axes(projection=ccrs.PlateCarree())
ds["temp"].isel(time=0).plot(ax=ax, transform=ccrs.PlateCarree())
ax.coastlines()
plt.show()

七、数据输入与输出

表:NetCDF 读取与保存

操作/方法 功能 输入参数 输出参数
xr.open_dataset 读取单个 NetCDF 文件 filename: 文件路径;engine: 读取引擎(如 netcdf4 xarray.Dataset 对象
xr.to_netcdf 保存数据为 NetCDF 文件 filename: 保存路径;mode: 写入模式(如 w 表示覆盖) 无返回值(直接写入文件)
xr.open_mfdataset 批量读取多文件数据集 paths: 文件路径列表;engine: 读取引擎(如 h5netcdf);parallel: 是否并行读取;preprocess: 预处理函数 合并后的 xarray.Dataset
xr.save_mfdataset 批量保存数据集到文件 datasets: 数据集列表;paths: 保存路径列表;encoding: 变量编码参数(如压缩设置) 无返回值(直接写入文件)

(一)读取 NetCDF 文件

ds = xr.open_dataset("data.nc")  # 读取单个文件
ds = xr.open_mfdataset("data/*.nc", combine="by_coords")  # 合并多个文件

示例代码

# 批量读取 NetCDF 文件
import xarray as xr
ds = xr.open_mfdataset("data/*.nc", engine="h5netcdf", parallel=True)

# 批量保存数据集
xr.save_mfdataset([ds1, ds2], ["output1.nc", "output2.nc"], encoding={var: {"zlib": True}})

(二)保存数据

# 保存为 NetCDF 文件并启用压缩
ds.to_netcdf("output.nc", encoding={"temp": {"zlib": True}})

八、性能优化

表:数据分块与性能优化

操作/方法 功能 输入参数 输出参数
.chunk() 设置数据分块大小 chunks: 分块字典(如 {"time": 10, "lat": 100} 带分块的 xarray.DataArray
.compute() 触发延迟计算 实际计算结果(xarray.DataArrayxarray.Dataset

(一)分块处理(Dask)

import dask.array as da
ds = ds.chunk({"time": 10})  # 将时间维度分块

(二) 内存优化

  • 使用 .persist().compute() 控制计算时机。
  • 避免不必要的中间变量。

示例代码

# 设置分块
da_chunked = da.chunk({"time": 10, "lat": 100})

# 触发计算
result = da_chunked.mean().compute()

九、总结

(一)核心流程

  1. 读取数据 → 2. 访问变量 → 3. 选择/切片 → 4. 计算/分析 → 5. 保存/可视化

(二)关键优势

  • 标签化操作:通过维度名和坐标直接访问数据。
  • 高效处理:支持多维数据、地理信息和大文件分块。
  • 易用性:与 Pandas、Matplotlib 无缝集成。

(三)注意事项

  • 使用 .sel().isel() 时注意维度名称和索引范围。
  • 大数据集需结合 Dask 分块处理(.chunk())。
  • 保存时启用压缩(zlib=True)可减少文件体积。

网站公告

今日签到

点亮在社区的每一天
去签到