• 问答
  • 技术
  • 实践
  • 资源
ViT-int8 on TVM:提速 4.6 倍,比 TRT 快 1.5 倍
技术讨论

作者丨火柴天堂
来源丨https://zhuanlan.zhihu.com/p/365686106
编辑丨极市平台

TL;DR

5个步骤教你在TVM里优化ViT的int8实现,提速4.6倍,比TRT快1.5倍。

背景知识:ViT模型及其速度

Transformer 模型在 NLP 领域得到了广泛的应用,去年谷歌的一项工作将 Transformer 模型应用在了视觉领域,并在 ImageNet 等图像分类数据集上取得了出色的效果,该模型被叫做 Vision Transformer (ViT) (论文链接:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)。

那ViT在硬件上inference的速度怎么样呢?

以复现的输入大小为 1x3x224x224 的 ViT 模型为基础,我们在 GTX1660 显卡上分别应用 cudnn、tensorrt、tvm 三种后端测试了模型速度,结果对比如下:

后端 B16_224 B32_224
cuda10.2_cudnn7.6.5_fp32 23.73 ms 10.03 ms
cuda10.2_tensorrt7.1_fp32 16.77 ms 5.77 ms
cuda10.2_tensorrt7.1_int8 19.32 ms 6.54 ms
tvm_tune500_fp32 59.64 ms 9.27 ms
tvm_tune500_int8 59.41 ms 9.14 ms

从速度结果可以看到:在全精度 fp32 时,模型在 tensorrt 上的表现最好,而 tvm 调优的结果不尽如人意;对于 int8 量化,tensorrt 的量化结果竟然比浮点还要差,另外 tvm 量化后的结果和量化之前 fp32 的结果也相差无几。导致 vit 模型量化后速度表现拉跨的原因主要是,vit 模型的计算量集中在 batch_matmul 算子上,而无论是 tensorrt 还是 tvm,其对batch_matmul 算子的量化支持并不是特别好;对于 tensorrt 这种闭源库我们显然无能为力。

而对于 tvm,由于其开源,可以尝试增加对 batch_matmul 的量化支持。

我们决定自己试一把。

ps:下述工作已经被完整merge到TVM里了,具体可见:https://github.com/apache/tvm/pull/7814

STEP 1:QuantizeAnnotate

  • 量化节点标注的 pass,告诉 relay 一些算子需要量化,并根据算子功能插入模拟量化节点,模拟量化节点模拟了由浮点数映射到定点数的误差
  • 相关文件:python/tvm/relay/quantize/_annotate.py
  • 这里我们增加 batch_matmul 的 rewrite 函数:
@register_annotate_function("nn.batch_matmul")
def batch_matmul_rewrite(ref_call, new_args, ctx):
    """Rewrite function for batch_matmul"""
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
        if _analysis.check_constant(lhs_expr):
            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT)
        else:
            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

    if rhs_kind is None or rhs_kind == QAnnotateKind.ACTIVATION:
        if _analysis.check_constant(rhs_expr):
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
        else:
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)

    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

STEP 2:QuantizeCalibrate

  • 量化校准的 pass,调整量化的阈值和缩放比,避免模型量化后精度下降
  • 相关文件:python/tvm/relay/quantize/_calibrate.py
  • 由于 tvm 仅有 kl 散度校准算法,且其在 ViT 模型量化时表现不佳,因此我们增加了一个简单的 percentile 校准方法,以挽救 ViT 模型精度:
def _find_scale_by_percentile(arr, percentile=0.99999):
    assert isinstance(arr, np.ndarray)
    x = np.abs(arr)
    max_k = int(x.size * percentile)
    return np.partition(x, max_k)[max_k]

def _percentile_scale(mod, dataset):
    cfg = quantize.current_qconfig()
    chunk_by = cfg.calibrate_chunk_by
    scales = []
    for samples in collect_stats(mod, dataset, chunk_by):
        logging.info("finding threshold with percentile for calibration...")
        with mp.Pool() as pool:
            scales += list(pool.map(_find_scale_by_percentile, samples))

    def func(_):
        scale = scales[func.scale_idx]
        func.scale_idx += 1
        return scale

    func.scale_idx = 0

STEP 3:QuantizeRealize

  • 量化实现的 pass,将 fp32 计算图转换为真实的低比特定点数的计算图
  • 相关文件:src/relay/quantize/realize.cc
  • 这里我们增加对 batch_matmul 支持的 Realize 函数:
Expr BatchMatmulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();

  Expr ldata = lhs->data;
  Expr rdata = rhs->data;
  DataType dtype = cfg->dtype_input;

  if (lhs->dtype != dtype) {
    ldata = Cast(ldata, dtype);
  }
  if (rhs->dtype != dtype) {
    rdata = Cast(rdata, dtype);
  }

  const auto ref_attrs = ref_call->attrs.as<BatchMatmulAttrs>();
  auto attrs = make_object<BatchMatmulAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.batch_matmul")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", BatchMatmulRealize);
  • 由于 batch_matmul 的 int8 计算涉及到 out_dtype,因此同时需要更改 include/tvm/relay/attrs/nn.h 中的 BatchMatmulAttrs 和 src/relay/op/nn/nn.c 中的 MakeBatchMatmul 的定义:
/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
  DataType out_dtype;

  TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
    // use 0 bits to indicate none.
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type, set to explicit type under mixed precision setting");
  }
};

// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) {
  auto attrs = make_object<BatchMatmulAttrs>();
  attrs->out_dtype = out_dtype;
  static const Op& op = Op::Get("nn.batch_matmul");
  return Call(op, {x, y}, Attrs(attrs), {});
}

STEP 4:topi-compute \& topi-schedule

  • 通过以上三个步骤,tvm 的 relay 图中的 batch_matmul_fp32 计算可以量化变成 batch_matmul_int8 计算,接下来需要实现 int8 算子的 compute 和 schedule
  • compute:用来描述算子的 tensor 计算过程
  • schedule:基于特定平台对于算子的计算进行调度,通过 tile,split,reorder,memory_cache 等操作,从而达到更快的运行效率
  • topi 全称为 tvm operator inventory,是 tvm 为多种平台提供的多种算子的计算和调度实现,我们这里将 batch_matmul_int8 的计算和调度实现注册进 topi 之中
  • 相关文件:python/tvm/topi/cuda/batch_matmul.py
  • 在具体实现中,我们仅考虑 batch_matmul_int8 的 cuda 平台,并使用 dp4a 指令实现 int8 计算调度:
@autotvm.register_topi_compute("batch_matmul_int8.cuda")
def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
    """Batch Matmul operator for int8 on CUDA"""
    if out_dtype is None:
        out_dtype = x.dtype

    x_shape = get_const_tuple(x.shape)
    y_shape = get_const_tuple(y.shape)
    assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"

    XB, M, XK = x.shape
    YB, N, YK = y.shape
    assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
    assert XK == YK, "shapes of x and y is inconsistent"

    nB = tvm.te.max(XB, YB)
    nK = ((XK + 3) // 4) * 4
    reduce_k = te.reduce_axis((0, nK), name="k")

    # pad for _dp4a vectorize
    pad_x = te.compute(
        (XB, M, nK),
        lambda b, i, j: tvm.te.if_then_else(
            j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j]
        ),
    )
    pad_y = te.compute(
        (YB, N, nK),
        lambda b, i, j: tvm.te.if_then_else(
            j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j]
        ),
    )

    out = te.compute(
        (nB, M, N),
        lambda b, i, j: te.sum(
            pad_x[b if XB != 1 else 0, i, reduce_k].astype(out_dtype)
            * pad_y[b if YB != 1 else 0, j, reduce_k].astype(out_dtype),
            axis=[reduce_k],
        ),
        tag="batch_matmul_int8",
    )
    cfg.add_flop(XB * M * N * nK * 2)
    return out

@autotvm.register_topi_schedule("batch_matmul_int8.cuda")
def schedule_batch_matmul_int8(cfg, outs):
    """Batch Matmul schedule for int8 on CUDA"""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if "batch_matmul_int8" in op.tag:
            _schedule_batch_matmul_int8(cfg, s, op.output(0))

    traverse_inline(s, outs[0].op, _callback)
    return s

_dp4a = dp4a("shared", "shared", "local")

def _schedule_batch_matmul_int8(cfg, s, output):
    input_x, input_y = s[output].op.input_tensors

    B, M, K = get_const_tuple(input_x.shape)
    _, N, _ = get_const_tuple(input_y.shape)

    k_factor = 4
    assert K % k_factor == 0, "Input dimension must divide {}".format(k_factor)
    if K % 16 == 0:
        k_factor = 16

    cfg.define_split("tile_f", B, num_outputs=4)
    cfg.define_split("tile_m", M, num_outputs=4)
    cfg.define_split("tile_n", N, num_outputs=4)
    cfg.define_split("tile_k", K // k_factor, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 256, 512, 1024])

    batch_matmul_op = s.outputs[0]
    s[input_x].compute_inline()
    s[input_y].compute_inline()

    x_cache = s.cache_read(input_x, "shared", [batch_matmul_op])
    y_cache = s.cache_read(input_y, "shared", [batch_matmul_op])
    batch_matmul_cache = s.cache_write(batch_matmul_op.output(0), "local")

    # tile reduce axis
    ko = batch_matmul_cache.op.reduce_axis[0]
    ko, ki = s[batch_matmul_cache].split(ko, factor=4)
    ko, kt = cfg["tile_k"].apply(s, batch_matmul_cache, ko)
    # dp4a tensorize
    s[batch_matmul_cache].tensorize(ki, _dp4a)

    # tile axis
    f, m, n = batch_matmul_op.axis
    kernel_scope, f = s[batch_matmul_op].split(f, nparts=1)

    bf, vf, tf, fi = cfg["tile_f"].apply(s, batch_matmul_op, f)
    bm, vm, tm, mi = cfg["tile_m"].apply(s, batch_matmul_op, m)
    bn, vn, tn, ni = cfg["tile_n"].apply(s, batch_matmul_op, n)
    s[batch_matmul_op].reorder(bf, bm, bn, vf, vm, vn, tf, tm, tn, fi, mi, ni)

    # bind axis
    s[batch_matmul_op].bind(bf, tvm.te.thread_axis("blockIdx.z"))
    s[batch_matmul_op].bind(bm, tvm.te.thread_axis("blockIdx.y"))
    s[batch_matmul_op].bind(bn, tvm.te.thread_axis("blockIdx.x"))

    s[batch_matmul_op].bind(vf, tvm.te.thread_axis("vthread"))
    s[batch_matmul_op].bind(vm, tvm.te.thread_axis("vthread"))
    s[batch_matmul_op].bind(vn, tvm.te.thread_axis("vthread"))

    s[batch_matmul_op].bind(tf, tvm.te.thread_axis("threadIdx.z"))
    s[batch_matmul_op].bind(tm, tvm.te.thread_axis("threadIdx.y"))
    s[batch_matmul_op].bind(tn, tvm.te.thread_axis("threadIdx.x"))

    # cache compute at
    s[batch_matmul_cache].compute_at(s[batch_matmul_op], tn)
    fo, mo, no = batch_matmul_cache.op.axis[:3]
    s[batch_matmul_cache].reorder(ko, kt, fo, mo, no, ki)

    # for load in [splited_x_op, splited_y_op]
    for load in [x_cache, y_cache]:
        s[load].compute_at(s[batch_matmul_cache], ko)
        outer, inner = s[load].split(s[load].op.axis[-1], factor=k_factor)
        s[load].vectorize(inner)

        fused = s[load].op.axis[:-1] + [outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=cfg["tile_n"].size[2])
        fused, ty = s[load].split(fused, factor=cfg["tile_m"].size[2])
        fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])

        s[load].bind(tz, tvm.te.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.te.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.te.thread_axis("threadIdx.x"))

    # max unroll
    s[batch_matmul_op].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[batch_matmul_op].pragma(kernel_scope, "unroll_explicit", False)

    return s

STEP 5:relay op strategy

  • relay 通过 strategy 类为每个算子选择合适的 compute 和 schedule
  • 相关文件:python/tvm/relay/op/strategy/cuda.py
  • 我们在这里增加对 batch_matmul_int8 算子计算和调度的选择策略:
@batch_matmul_strategy.register(["cuda", "gpu"])
def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
    """batch_matmul cuda strategy"""
    strategy = _op.OpStrategy()
    x, y = inputs
    if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32":
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
            wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8),
            name="batch_matmul_int8.cuda",
            plevel=10,
        )
    else:
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.cuda.batch_matmul),
            wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
            name="batch_matmul.cuda",
            plevel=10,
        )
    ...
  • 同时,由于我们实现的 batch_matmul_int8 的计算需要 out_dtype 作为参数,因此也需要同时更改 python/tvm/relay/op/strategy/generic.py 文件中的 wrap_compute_batch_matmul 函数,增加一个 need_out_dtype 的参数:
# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False):
    """wrap batch_matmul topi compute"""

    def _compute_batch_matmul(attrs, inputs, out_type):
        args = [inputs[0], inputs[1], out_type.shape]
        if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
        if need_out_dtype:
            args.append(out_type.dtype)
        return [topi_compute(*args)]

    return _compute_batch_matmul

测试速度

通过以上五个步骤,给定一个深度模型,通过 tvm 的量化 pass,我们可以得到将 batch_matmul 算子从浮点数计算转换为低比特定点数计算的模型,简单示例如下:

# onnx 模型 -> tvm realy 模型
G = onnx.load(open("/path/of/onnx", "rb"))
mod, params = tvm.relay.frontend.from_onnx(G, {"data": [1, 3, 224, 224]})

# tvm 量化 pass,这里的 qconfig 可以根据具体情况定义,这里仅使用 global_scale,也可以通过构造校准数据集进行量化校准
with tvm.relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0, skip_dense_layer=False, skip_conv_layers=[0]):
    mod = tvm.relay.quantize.quantize(mod, params)

# 抽取 autotvm 的 task,可以看到 batch_matmul 算子的 task 已经从 batch_matmul.cuda 变成了 batch_matmul_int8.cuda 了
tasks = tvm.autotvm.task.extract_from_program(mod['main'], target='cuda', params=params)
for i, task in enumerate(tasks):
    prefix = "[Task %2d/%2d %s] " % (i + 1, len(tasks), task.name)
    print(prefix, task)

更具体的示例可以参看 tests/python/nightly/quantization/test_quantization_accuracy_for_vit.py 的测试用例。

在 tvm 中增加了对 batch_matmul 的量化支持后,我们首先测试了 ViT 模型在 tvm 上量化后的速度表现:

后端 B16_224 B32_224
tune500_int8 (before PR) 59.41 ms 9.14 ms
tune500_int8 (after PR) 12.77 ms 4.38 ms

可以看到,PR 之后的 ViT 量化模型在 tvm 上有了比较大的速度提升。

此外,我们还测试了 tvm 量化模型的精度表现,在 imagenet 验证集的 5 万张图片上测试 accuracy top1/top5 的结果,在量化校准时我们使用 64 张图片作为校准集合:

精度来源 量化校准算法 B16_224 B32_224
paper 77.91/- 73.38/-
tvm-fp32 78.49/93.68 73.27/90.45
tvm-int8 kl_divergence 63.52/84.72 66.26/86.87
tvm-int8 percentile_0.99999 72.92/90.63 72.78/90.21
tvm-int8 percentile_0.9999 75.25/91.96 -/-

从表中可以看到,ViT 的量化模型在 kl 散度的校准方法下模型精度下降严重,而采用 percentile 方法时,B32_224 模型的量化精度基本可以接近 fp32 模型的水平,而 B16_224 模型的量化精度相比 kl 也有较大的提升,但其相对全精度模型精度下降还较多,后续还可以尝试更多的校准算法以得到更高精度的量化模型。

总结

本篇文章以优化 ViT 模型为目标,在 tvm 中尝试提升 ViT 量化模型的速度,增加了 batch_matmul 算子的量化支持,并有效提升了模型的运行效率,同时梳理了在 tvm 中添加一个算子量化支持的具体步骤,为开源社区的发展贡献了一点微薄的力量,这也是组内之前那片量化文章预告之后第一个相对较大的社区贡献:

https://zhuanlan.zhihu.com/p/355598250​zhuanlan.zhihu.com![图标](https://pic3.zhimg.com/v2-9bf2ac6cdfcf9ef415f448fd83e49eee_ipico.jpg)

另外,这里是组里量化大佬们的技术直播回放链接,里面介绍了组里 ICLR 2021 的一篇工作,以及量化能够实质性加速模型的几种路径:

TechBeat​www.techbeat.net

欢迎对这方面感兴趣的同学加入我们,可直接投递简历到 liuliang1\@sensetime.com,也可参看以下招聘详情:

https://zhuanlan.zhihu.com/p/357332343​zhuanlan.zhihu.com![图标](https://pic2.zhimg.com/v2-229eae5e704b1d0ddc3232d05feed371_ipico.jpg)

  • 0
  • 0
  • 54
收藏
暂无评论