训练 API

deepspeed.initialize() 在其第一个参数(类型为 DeepSpeedEngine)中返回一个训练引擎。此引擎用于推进训练过程。

for step, batch in enumerate(data_loader):
    #forward() method
    loss = model_engine(batch)

    #runs backpropagation
    model_engine.backward(loss)

    #weight update
    model_engine.step()

前向传播

反向传播

优化器步骤

梯度累积

模型保存

此外,当创建 DeepSpeed 检查点时,会在其中添加一个脚本 zero_to_fp32.py,该脚本可用于将 fp32 主权重重建为单个 PyTorch state_dict 文件。