训练 API
deepspeed.initialize()
在其第一个参数(类型为 DeepSpeedEngine
)中返回一个训练引擎。此引擎用于推进训练过程。
for step, batch in enumerate(data_loader):
#forward() method
loss = model_engine(batch)
#runs backpropagation
model_engine.backward(loss)
#weight update
model_engine.step()
前向传播
反向传播
优化器步骤
梯度累积
模型保存
此外,当创建 DeepSpeed 检查点时,会在其中添加一个脚本 zero_to_fp32.py
,该脚本可用于将 fp32 主权重重建为单个 PyTorch state_dict
文件。