stridedslice算子定义

strided_slice算子是tensorflow独有的一个算子,用于从输入中按照一定规律挑选数据,功能非常强大,官方文档定义的算子原型如下所示。

1
2
3
4
tf.strided_slice(
    input_, begin, end, strides=None, begin_mask=0, end_mask=0, ellipsis_mask=0,
    new_axis_mask=0, shrink_axis_mask=0, var=None, name=None
)

一般来讲,begin,end,strides是1D矢量,个数必须相同,不能多于input的维数,mask参数是与begin,end,strides对应的位变量,用于精细控制挑选数据的行为。

参数

从上面定义可以看出,strided_slice的参数其实蛮多的,细看的话,其实可以分为2类

  • input,begin,end,strides
  • begin_mask,end_mask,ellipsis_mask,new_axis_mask, shrink_axis_mask

首先我们看第一类,input,begin,end,strides这几个参数定义了strided_slice算子的基本功能,即从input中按照从begin开始,strides为步进挑选数据,直到到达end,选择的区间是左闭右开的,意味着并不包含end。这几个参数的含义很容易理解,并不存在什么疑义,本文就不再细讲了。

本文重点关注后面的这5个mask参数,首先我们看官方文档的介绍:

1
2
3
4
5
6
7
If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range in that dimension is used instead. end_mask works analogously, except with the end range.

If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.

If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a new length 1 dimension is added at this point in the output tensor.

If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i] are ignored in this case.

相信绝大多数人看了上文的解释,觉得这5个mask功能还是很明确的,很好理解的,但是,一执行TensorFlow,就会发现实际情况远不是字面的那样,很多场景运行结果完全无法解释清楚,google搜索也没有什么实质的解释,一头雾水。

计算过程

下面我们详细分析一下这5个mask之间的关系以及详细计算过程。首先,我们要明确一点,这5个mask并不是独立的,其实它们是互相影响的,尤其是后面3个mask:ellipsis_mask,new_axis_mask,shrink_axis_mask。其次,要注意到,mask的位数是可以少于input的维数的,例如,【6,3,4,5】维的输入数据,你可以指定begin为【0,1】,end为【3,5】,strides为【1,2】,缺失的维度默认全选,这种模式叫稀疏模式,反之,如果begin数据个数和input的维度相同,叫稠密模式。

有了这些共识,我们接下来就可以分析这几个mask的实际功能了。

  • begin_mask,end_mask功能比较简单明确,就是单纯的忽略对应位的begin,end设置,要注意的是,有ellipsis_mask的时候,索引位置需要考虑ellipsis_mask,可能需要适当延后。
  • ellipsis_mask:这个mask最多只能设置一位,如果你注意到了,文档上提到过,这个mask设置的时候,as many unspecified dimensions as needed will be inserted between other dimensions,该怎么理解呢?其实是说,如果某位有ellipsis_mask,那么这一个维度将全选,其后未指定的维度也尽可能的全选。稠密模式每个维度都指定了,ellipsis_mask只能影响第i个维度,稀疏模式的话,begin比input少几个维度,elipsis_mask除了影响第i维度,之后的几个维度也会是全选状态,换言之,ellipsis_mask相当于开启全选模式,直到遇到用户指定才退出这种状态,它影响的是一个区间,这也是为什么它只能最多有一位的原因。
  • new_axis_mask:简单来说,就是在指定位上增加维度为1的shape,但是,这个增加过程是重复计算的,例如input为【6,3,4,5】,new_axis_mask为0b1111的话,第一次发现最低位为1,索引0增加1,shape变为【1,6,3,4,5】,第二次发现第二位为1,在索引1增加1,shape变为【1,1,6,3,4,5】,第三次发现第三位为1,在索引2增加1,shape变为【1,1,1,6,3,4,5】,第四次发现第四位为1,在索引3增加1,shape变为【1,1,1,1,6,3,4,5】,常规模式这样就可以了,但是,加上ellipsis_mask以后,情况就变得复杂了,因为新增axis以后,稠密模式就不存在了,一定会变成稀疏模式,ellipsis_mask位置会全选一个区间,这时新增的索引就需要考虑ellipsis_mask的位置。特别注意一点,ellipsis_mask和new_axis_mask同时设置的位,new_axis_mask位会失效。ellipsis_mask和shrink_axis_mask同时设置的位,shrink_axis_mask位会失效。
  • shrink_axis_mask:把对应索引处维度强制降为1。

综上,其实问题全是因为ellipsis_mask影响的是一个区间引起的,它可能导致其它mask的索引被迫延后引起的。分析了这么半天,我们还是举个例子来看一下详细的计算过程吧。

假设输入是【6,3,4,10】,tf.strided_slice(input_data, [0, 0, 2, 2], [3, 2, 4, 8], [1, 1, 1, 1], new_axis_mask=0b1001, shrink_axis_mask=4, ellipsis_mask=8),因为最高位ellipsis_mask和new_axis_mask同时设置上了,所以new_axis_mask其实等价于0b0001,新增完axis,shape变为【1,6,3,4,10】,然后列下表计算

axis 1 6 3 4(此位ellipsis,将替换原始输入) 10(此位没有对应数据,全选)
begin 0 0 2 2–>0 0
end 3 2 4 8–>4 10
strides 1 1 1 1–>1 1
output 1 2 1 4 10

最后把结果第三位删掉(强制变成维度为1),所以最终的结果是【1,2,4,10】

考虑另一个例子,输入不变,还是【6,3,4,10】,tf.strided_slice(input_data, [0, 0, 2, 2], [3, 2, 4, 8], [1, 1, 1, 1], new_axis_mask=0b1001, shrink_axis_mask=4, ellipsis_mask=4),ellipsis_mask和shrink_axis_mask同时存在,shrink_axis_mask将失效,先做new_axis,第一次发现第0位是1,shape变为【1,6,3,4,10】,第二次发现第三位是1,这个位置比ellipsis的位置要远,input总共添加2维,将变成6维,begin只有4维,ellipsis将从第3位开始全选,维持1+(6-4)维,所以这个插入的位置只能是最后一个,也就是说,shape将变为【1,6,3,4,10,1】,然后列表计算

axis 1 6 3(ellipsis) 4(ellipsis) 10(ellipsis) 1
begin 0 0 2–>0 0 0 2
end 3 2 4–>3 4 10 8
strides 1 1 1 1 1 1
output 1 2 3 4 10 1

结果是【1,2,3,4,10,1】

上述计算过程的核心是列表计算,核心思想是利用下述原理:

  1. 添加的维度为1其实并不影响存贮的数据,只是view或者解释发生了改变。
  2. 添加的维度不管怎么计算,只能输出1,该位对结果没影响,选择数据时可以跳过新增的维度
  3. shrink_mask位可以认为强制输出第一个元素,可以认为没有

主要步骤

模块 功能
remove_conflict 存在ellipsis_mask的话,shrink_mask对应位无效,new_axis_mask对应位也忽略,修改输入mask数据
process_new_axis_mask 对inputShape这个vector循环做insert 1的操作,增加时需要考虑ellipsis区间,实现维度增加1,增加后,inputShape最多变为原先2倍。
process_elippsis_mask 强制设置对应ellipsis区间begin, end, stride数据为【0,N, 1】
fill_missing_shape 填补缺失的begin,end,stride,数据选择inputShape对应的最大值
clamp(begin) begin为负则加N,使其落在【0, N-1】区间,使输入标准化
clamp(end) end为负则加N,使其落在【0, N-1】区间,使输入标准化
process_begin_mask 有begin_mask的话,正向stride>0,设begin设为0,反向stride<0,设begin为N
process_end_mask 有end_mask的话,正向stride>0,设end为N,反向stride<0,设end为0
process_shrink_mask 强制end,stride为【begin+1,1】,对应位只输出维度1,等价于可以忽略
compute_output_shape 根据inputShape, begin, end, stride计算出来新的outputShape:列等式,逐位计算,inputShape为1的话,只能输出1,结果删掉shrink_mask指定的位
compute_pick_shape 基本等同outputShape计算,会跳过新增的维度,确保和最初的输入维度相同,不用考虑shrink的维度
pick_data 根据pickShape对应的begin,end,stride挑选数据出来,挑选维度最多其实与原先维度相同,新增的1可以忽略掉。