使用 Mosaic 流式处理加载数据
本文介绍如何使用 Mosaic 流式处理将数据从 Apache Spark 转换为与 PyTorch 兼容的格式。
Mosaic 流式处理是一个开源数据加载库。 它可直接从已加载为 Apache Spark 数据帧的数据集对深度学习模型进行单节点或分布式训练和评估。 Mosaic 流式处理主要支持 Mosaic Composer,但也与本机 PyTorch、PyTorch Lightning 和 TorchDistributor 集成。 与传统 PyTorch DataLoaders 相比,Mosaic 流式处理提供了一系列优势,包括:
与任何数据类型(包括图像、文本、视频和多模式数据)的兼容性。
支持主要云存储提供程序(AWS、OCI、Azure、Databricks UC 卷和任何与 S3 兼容的对象存储,例如 Cloudflare R2、Coreweave、Backblaze b2 等)
最大限度地保证正确性,以及最大程度地提升性能、灵活性和易用性。 有关详细信息,请查看相应的主要功能页。
有关 Mosaic 流式处理的一般信息,请查看流式处理 API 文档。
注意
Mosaic 流式处理已预安装到 Databricks Runtime 15.2 ML 及更高版本。
使用 Mosaic 流式处理从 Spark 数据帧加载数据
Mosaic 流式处理提供了一个简单的工作流,用于从 Apache Spark 转换为 Mosaic 数据分片 (MDS) 格式,然后可以加载该格式在分布式环境中使用。
建议的工作流为:
- 使用 Apache Spark 来加载数据,还可以选择对数据进行预处理。
- 使用
streaming.base.converters.dataframe_to_mds
将数据帧保存到磁盘进行暂时存储和/或保存到 Unity Catalog 卷进行持久存储。 此数据将以 MDS 格式存储,并且可以通过对压缩和哈希的支持进行进一步优化。 高级用例还可以包括使用 UDF 对数据进行预处理。 有关详细信息,请查看将 Spark 数据帧转换为 MDS 的教程。 - 使用
streaming.StreamingDataset
将必要的数据加载到内存中。StreamingDataset
是 PyTorch 的 IterableDataset 的一个版本,它具有可弹性确定的随机处理,可实现快速的中时期恢复。 有关详细信息,请查看 StreamingDataset 文档。 - 使用
streaming.StreamingDataLoader
加载训练/评估/测试所需的数据。StreamingDataLoader
是 PyTorch 的 DataLoader 的一个版本,它提供额外的检查点/恢复接口,用于跟踪此设置级别中模型看到的示例数。
有关端到端示例,请参阅以下笔记本:
使用 Mosaic 流式处理笔记本简化从 Spark 到 PyTorch 的数据加载
疑难解答:身份验证错误
如果使用 StreamingDataset
从 Unity Catalog 卷加载数据时看到以下错误,请设置环境变量,如下所示。
ValueError: default auth: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method.
注意
如果使用 TorchDistributor
运行分布式训练时看到此错误,还必须在工作器节点上设置环境变量。
db_host = "https://your-databricks-host.databricks.com"
db_token = "YOUR API TOKEN" # Create a token with either method from https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-authentication-methods
def your_training_function():
import os
os.environ['DATABRICKS_HOST'] = db_host
os.environ['DATABRICKS_TOKEN'] = db_token
# The above function can be distributed with TorchDistributor:
# from pyspark.ml.torch.distributor import TorchDistributor
# distributor = TorchDistributor(...)
# distributor.run(your_training_function)