背景

tf,PyTorch,numpy的广播其实和python的是一样的,算子支持广播的话可以简化代码(减少准备数据的代码),减少内存消耗。

例如,一个3 * 3的张量,减去一个常量,如果不支持广播,需要先将常量复制成3 * 3的,然后2个张量做减法。支持广播的话,3 * 3的张量可以直接减去1的张量。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
>>> a = tf.constant([[1,2,3],[4,5,6],[7,8,9]])
>>> b = tf.constant([2])
>>> c = tf.subtract(a,b)
>>>
>>>
>>> a
<tf.Tensor: id=3, shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)>
>>> b
<tf.Tensor: id=4, shape=(1,), dtype=int32, numpy=array([2], dtype=int32)>
>>> c
<tf.Tensor: id=5, shape=(3, 3), dtype=int32, numpy=
array([[-1,  0,  1],
       [ 2,  3,  4],
       [ 5,  6,  7]], dtype=int32)>
>>>

shape兼容检查

如果两个张量的后缘维度(从末尾开始算起的维度)的轴长度相符或其中一方的长度为1,则认为它们是广播兼容的。广播会在缺失维度和(或)轴长度为1的维度上进行。

image-20200227110311665

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
	    const int maxDimensions = input0->dimensions();
        const int diffDimension = input0->dimensions() - input1->dimensions();

        // else broadcast
        // 从右侧开始计算
        for (int i = maxDimensions-1; i >=0 ; --i) {
            // 将两个输入的shape上下排在一起,右侧对齐,第二个shape缺失的部分补位1
            // 然后从右侧开始逐列检查,如果上下不等,并且没有1,那么算是不兼容,这样没法做广播
            // 也就是说 [3,4,6]与[4,6]是兼容的,但是[3,4,6]与[2,6]是不兼容的。
            // 不等的时候,必须其中一个为1,否则就算是不兼容
            auto input0Length = input0->length(i);
            auto input1Length = 1;
            if (i >= diffDimension) {
                input1Length = input1->length(i-diffDimension);
            }
            if (input0Length != input1Length && input1Length != 1 && input0Length != 1) {
                MNN_PRINT("Don't support broadcast for binaryOp, i0=%d, i1=%d\n", input1Length, input0Length);
                return false;
            }
            // 更新输出shape对应位为二者中大的那个
            buffer.dim[i].extent = std::max(input0Length, input1Length);
        }

广播计算

核心是计算stride,对于可以广播的维度,它对应的shape值是1,stride为0,这样其实索引的还是原先的数据。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// 输入shape不同,通过广播方式计算,核心是计算dim及stride,广播最多支持6维
#define MAX_DIM 6
            MNN_ASSERT(output->dimensions() <= MAX_DIM);
            int dims[MAX_DIM];
            int stride[MAX_DIM];
            int iStride0[MAX_DIM];
            int iStride1[MAX_DIM];
            for (int i = MAX_DIM - 1; i >= 0; --i) {
                dims[i]     = 1;
                stride[i]   = 0;
                iStride0[i] = 0;
                iStride1[i] = 0;
                int input0I = i - (output->dimensions() - input0->dimensions());
                int input1I = i - (output->dimensions() - input1->dimensions());
                if (i < output->dimensions()) {
                    dims[i]   = output->length(i);
                    stride[i] = output->stride(i);
                }
                if (input0I >= 0 && input0->length(input0I) != 1) {
                    iStride0[i] = input0->stride(input0I);
                }
                if (input1I >= 0 && input1->length(input1I) != 1) {
                    iStride1[i] = input1->stride(input1I);
                }
            }
            for (int w = 0; w < dims[5]; ++w) {
                auto ow  = outputData + w * stride[5];
                auto i0w = input0Data + w * iStride0[5];
                auto i1w = input1Data + w * iStride1[5];
#define PTR(x, y, i)                      \
    auto o##x  = o##y + x * stride[i];    \
    auto i0##x = i0##y + x * iStride0[i]; \
    auto i1##x = i1##y + x * iStride1[i]

                for (int v = 0; v < dims[4]; ++v) {
                    PTR(v, w, 4);
                    for (int u = 0; u < dims[3]; ++u) {
                        PTR(u, v, 3);
                        for (int z = 0; z < dims[2]; ++z) {
                            PTR(z, u, 2);
                            for (int y = 0; y < dims[1]; ++y) {
                                PTR(y, z, 1);
                                for (int x = 0; x < dims[0]; ++x) {
                                    PTR(x, y, 0);
                                    *ox = static_cast<Tout>(f(*i0x, *i1x));
                                }
                            }
                        }
                    }
                }
            }
#undef MAX_DIM
#undef PTR