Tensorflow Stridedslice算子是怎么计算的
文章目录
stridedslice算子定义
strided_slice算子是tensorflow独有的一个算子,用于从输入中按照一定规律挑选数据,功能非常强大,官方文档定义的算子原型如下所示。
|
|
一般来讲,begin,end,strides是1D矢量,个数必须相同,不能多于input的维数,mask参数是与begin,end,strides对应的位变量,用于精细控制挑选数据的行为。
参数
从上面定义可以看出,strided_slice的参数其实蛮多的,细看的话,其实可以分为2类
- input,begin,end,strides
- begin_mask,end_mask,ellipsis_mask,new_axis_mask, shrink_axis_mask
首先我们看第一类,input,begin,end,strides这几个参数定义了strided_slice算子的基本功能,即从input中按照从begin开始,strides为步进挑选数据,直到到达end,选择的区间是左闭右开的,意味着并不包含end。这几个参数的含义很容易理解,并不存在什么疑义,本文就不再细讲了。
本文重点关注后面的这5个mask参数,首先我们看官方文档的介绍:
|
|
相信绝大多数人看了上文的解释,觉得这5个mask功能还是很明确的,很好理解的,但是,一执行TensorFlow,就会发现实际情况远不是字面的那样,很多场景运行结果完全无法解释清楚,google搜索也没有什么实质的解释,一头雾水。
计算过程
下面我们详细分析一下这5个mask之间的关系以及详细计算过程。首先,我们要明确一点,这5个mask并不是独立的,其实它们是互相影响的,尤其是后面3个mask:ellipsis_mask,new_axis_mask,shrink_axis_mask。其次,要注意到,mask的位数是可以少于input的维数的,例如,【6,3,4,5】维的输入数据,你可以指定begin为【0,1】,end为【3,5】,strides为【1,2】,缺失的维度默认全选,这种模式叫稀疏模式,反之,如果begin数据个数和input的维度相同,叫稠密模式。
有了这些共识,我们接下来就可以分析这几个mask的实际功能了。
- begin_mask,end_mask功能比较简单明确,就是单纯的忽略对应位的begin,end设置,要注意的是,有ellipsis_mask的时候,索引位置需要考虑ellipsis_mask,可能需要适当延后。
- ellipsis_mask:这个mask最多只能设置一位,如果你注意到了,文档上提到过,这个mask设置的时候,as many unspecified dimensions as needed will be inserted between other dimensions,该怎么理解呢?其实是说,如果某位有ellipsis_mask,那么这一个维度将全选,其后未指定的维度也尽可能的全选。稠密模式每个维度都指定了,ellipsis_mask只能影响第i个维度,稀疏模式的话,begin比input少几个维度,elipsis_mask除了影响第i维度,之后的几个维度也会是全选状态,换言之,ellipsis_mask相当于开启全选模式,直到遇到用户指定才退出这种状态,它影响的是一个区间,这也是为什么它只能最多有一位的原因。
- new_axis_mask:简单来说,就是在指定位上增加维度为1的shape,但是,这个增加过程是重复计算的,例如input为【6,3,4,5】,new_axis_mask为0b1111的话,第一次发现最低位为1,索引0增加1,shape变为【1,6,3,4,5】,第二次发现第二位为1,在索引1增加1,shape变为【1,1,6,3,4,5】,第三次发现第三位为1,在索引2增加1,shape变为【1,1,1,6,3,4,5】,第四次发现第四位为1,在索引3增加1,shape变为【1,1,1,1,6,3,4,5】,常规模式这样就可以了,但是,加上ellipsis_mask以后,情况就变得复杂了,因为新增axis以后,稠密模式就不存在了,一定会变成稀疏模式,ellipsis_mask位置会全选一个区间,这时新增的索引就需要考虑ellipsis_mask的位置。特别注意一点,ellipsis_mask和new_axis_mask同时设置的位,new_axis_mask位会失效。ellipsis_mask和shrink_axis_mask同时设置的位,shrink_axis_mask位会失效。
- shrink_axis_mask:把对应索引处维度强制降为1。
综上,其实问题全是因为ellipsis_mask影响的是一个区间引起的,它可能导致其它mask的索引被迫延后引起的。分析了这么半天,我们还是举个例子来看一下详细的计算过程吧。
假设输入是【6,3,4,10】,tf.strided_slice(input_data, [0, 0, 2, 2], [3, 2, 4, 8], [1, 1, 1, 1], new_axis_mask=0b1001, shrink_axis_mask=4, ellipsis_mask=8),因为最高位ellipsis_mask和new_axis_mask同时设置上了,所以new_axis_mask其实等价于0b0001,新增完axis,shape变为【1,6,3,4,10】,然后列下表计算
axis | 1 | 6 | 3 | 4(此位ellipsis,将替换原始输入) | 10(此位没有对应数据,全选) |
---|---|---|---|---|---|
begin | 0 | 0 | 2 | 2–>0 | 0 |
end | 3 | 2 | 4 | 8–>4 | 10 |
strides | 1 | 1 | 1 | 1–>1 | 1 |
output | 1 | 2 | 1 | 4 | 10 |
最后把结果第三位删掉(强制变成维度为1),所以最终的结果是【1,2,4,10】
考虑另一个例子,输入不变,还是【6,3,4,10】,tf.strided_slice(input_data, [0, 0, 2, 2], [3, 2, 4, 8], [1, 1, 1, 1], new_axis_mask=0b1001, shrink_axis_mask=4, ellipsis_mask=4),ellipsis_mask和shrink_axis_mask同时存在,shrink_axis_mask将失效,先做new_axis,第一次发现第0位是1,shape变为【1,6,3,4,10】,第二次发现第三位是1,这个位置比ellipsis的位置要远,input总共添加2维,将变成6维,begin只有4维,ellipsis将从第3位开始全选,维持1+(6-4)维,所以这个插入的位置只能是最后一个,也就是说,shape将变为【1,6,3,4,10,1】,然后列表计算
axis | 1 | 6 | 3(ellipsis) | 4(ellipsis) | 10(ellipsis) | 1 |
---|---|---|---|---|---|---|
begin | 0 | 0 | 2–>0 | 0 | 0 | 2 |
end | 3 | 2 | 4–>3 | 4 | 10 | 8 |
strides | 1 | 1 | 1 | 1 | 1 | 1 |
output | 1 | 2 | 3 | 4 | 10 | 1 |
结果是【1,2,3,4,10,1】
上述计算过程的核心是列表计算,核心思想是利用下述原理:
- 添加的维度为1其实并不影响存贮的数据,只是view或者解释发生了改变。
- 添加的维度不管怎么计算,只能输出1,该位对结果没影响,选择数据时可以跳过新增的维度
- shrink_mask位可以认为强制输出第一个元素,可以认为没有
主要步骤
模块 | 功能 |
---|---|
remove_conflict | 存在ellipsis_mask的话,shrink_mask对应位无效,new_axis_mask对应位也忽略,修改输入mask数据 |
process_new_axis_mask | 对inputShape这个vector循环做insert 1的操作,增加时需要考虑ellipsis区间,实现维度增加1,增加后,inputShape最多变为原先2倍。 |
process_elippsis_mask | 强制设置对应ellipsis区间begin, end, stride数据为【0,N, 1】 |
fill_missing_shape | 填补缺失的begin,end,stride,数据选择inputShape对应的最大值 |
clamp(begin) | begin为负则加N,使其落在【0, N-1】区间,使输入标准化 |
clamp(end) | end为负则加N,使其落在【0, N-1】区间,使输入标准化 |
process_begin_mask | 有begin_mask的话,正向stride>0,设begin设为0,反向stride<0,设begin为N |
process_end_mask | 有end_mask的话,正向stride>0,设end为N,反向stride<0,设end为0 |
process_shrink_mask | 强制end,stride为【begin+1,1】,对应位只输出维度1,等价于可以忽略 |
compute_output_shape | 根据inputShape, begin, end, stride计算出来新的outputShape:列等式,逐位计算,inputShape为1的话,只能输出1,结果删掉shrink_mask指定的位 |
compute_pick_shape | 基本等同outputShape计算,会跳过新增的维度,确保和最初的输入维度相同,不用考虑shrink的维度 |
pick_data | 根据pickShape对应的begin,end,stride挑选数据出来,挑选维度最多其实与原先维度相同,新增的1可以忽略掉。 |
文章作者 carter2005
上次更新 2020-05-19