使用 Mosaic 流式处理加载数据

本文介绍如何使用 Mosaic 流式处理将数据从 Apache Spark 转换为与 PyTorch 兼容的格式。

Mosaic 流式处理是一个开源数据加载库。 它可直接从已加载为 Apache Spark 数据帧的数据集对深度学习模型进行单节点或分布式训练和评估。 Mosaic 流式处理主要支持 Mosaic Composer,但也与本机 PyTorch、PyTorch Lightning 和 TorchDistributor 集成。 与传统 PyTorch DataLoaders 相比,Mosaic 流式处理提供了一系列优势,包括:

  • 与任何数据类型(包括图像、文本、视频和多模式数据)的兼容性。

  • 支持主要云存储提供程序(AWS、OCI、Azure、Databricks UC 卷和任何与 S3 兼容的对象存储,例如 Cloudflare R2、Coreweave、Backblaze b2 等)

  • 最大限度地保证正确性,以及最大程度地提升性能、灵活性和易用性。 有关详细信息,请查看相应的主要功能页。

有关 Mosaic 流式处理的一般信息,请查看流式处理 API 文档

注意

Mosaic 流式处理已预安装到 Databricks Runtime 15.2 ML 及更高版本。

使用 Mosaic 流式处理从 Spark 数据帧加载数据

Mosaic 流式处理提供了一个简单的工作流,用于从 Apache Spark 转换为 Mosaic 数据分片 (MDS) 格式,然后可以加载该格式在分布式环境中使用。

建议的工作流为:

  1. 使用 Apache Spark 来加载数据,还可以选择对数据进行预处理。
  2. 使用 streaming.base.converters.dataframe_to_mds 将数据帧保存到磁盘进行暂时存储和/或保存到 Unity Catalog 卷进行持久存储。 此数据将以 MDS 格式存储,并且可以通过对压缩和哈希的支持进行进一步优化。 高级用例还可以包括使用 UDF 对数据进行预处理。 有关详细信息,请查看将 Spark 数据帧转换为 MDS 的教程
  3. 使用 streaming.StreamingDataset 将必要的数据加载到内存中。 StreamingDataset 是 PyTorch 的 IterableDataset 的一个版本,它具有可弹性确定的随机处理,可实现快速的中时期恢复。 有关详细信息,请查看 StreamingDataset 文档
  4. 使用 streaming.StreamingDataLoader 加载训练/评估/测试所需的数据。 StreamingDataLoader 是 PyTorch 的 DataLoader 的一个版本,它提供额外的检查点/恢复接口,用于跟踪此设置级别中模型看到的示例数。

有关端到端示例,请参阅以下笔记本:

使用 Mosaic 流式处理笔记本简化从 Spark 到 PyTorch 的数据加载

获取笔记本

疑难解答:身份验证错误

如果使用 StreamingDataset 从 Unity Catalog 卷加载数据时看到以下错误,请设置环境变量,如下所示。

ValueError: default auth: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method.

注意

如果使用 TorchDistributor 运行分布式训练时看到此错误,还必须在工作器节点上设置环境变量。

db_host = "https://your-databricks-host.databricks.com"
db_token = "YOUR API TOKEN" # Create a token with either method from https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-authentication-methods

def your_training_function():
  import os
  os.environ['DATABRICKS_HOST'] = db_host
  os.environ['DATABRICKS_TOKEN'] = db_token

# The above function can be distributed with TorchDistributor:
# from pyspark.ml.torch.distributor import TorchDistributor
# distributor = TorchDistributor(...)
# distributor.run(your_training_function)