Model Size Gpu Storage Consumption

less than 1 minute read

Published:

LLM训练推理时GPU显存耗用量估计

推理时耗用量

目前模型的参数绝大多数都是float32类型,占用32位(bit)4个字节(Byte)。所以在全精度时一个参数就是4字节,1B个参数(就是\(1\times10^9\)个参数)共\(4\times1\times10^9B=4GB\)个字节,也就是需要消耗4GB显存。16精度的时候一个参数是16位2字节,显存消耗比全精度少一半,也就是1B个参数需要显存2GB。8精度的时候一个参数是8位1字节,再少一半,1B参数只需要1GB显存。最后4精度的时候一个参数4位0.5字节,1B参数只需要500MB即0.5GB显存。

以llama2为例。

Model Sizefull size1684
llama2 7B28GB14GB7GB3.5GB
llama2 13B52GB26GB13GB6.5GB
llama2 70B280GB140GB70GB35GB

当然以上只是向GPU加载模型使用的显存消耗,实际推理时还需要保存输入输出数据等,这也会消耗一部分显存,所以应当再额外空出一些显存。

训练时耗用量

训练时GPU需要保存模型,优化器状态,前向激活值和临时缓存。

  • 模型这里包括模型参数和梯度,均与参数规模一致
  • 前向激活值这里包括前向函数中保留的用于反向计算梯度的值,与batch size有关
  • 优化器状态这里包括动量,优化器权重等,控制器权重4Bytes,adam m 4Bytes,v 4Bytes
  • 临时缓存包括临时缓冲区和显存碎片

开启zero3且不offload时,全参数微调至少需要显存可以估计为16GB\(\times\)参数量,7B模型就需要7\(\times\)16=112GB(指仅考虑将模型、梯度和优化器放下需要的显存,前向计算还需要额外的显存)。

Checkpoint大小计算

保存checkpoint的时候只需要模型参数和优化器状态就行了,使用AdamW训练时优化器的参数量时模型本身的两倍,所以全精度保存checkpoint需要参数量\(\times5\times\)4GB,在用其他精度保存的时候,优化器精度一般保持不变,所以计算方式(16精度2字节为例)应该是(参数量\(\times4\times2\)+参数量\(\times2\))GB。