HLO Tensor Tiling and Pruning

Example

HLO Function

The fused computation below accumulates gradient terms: five f32[3,35,1024] tensors (parameters 0, 1, 7, 8 and 9) are summed together with two layer-norm mean-gradient corrections, each computed on f32[3,35] as (-a - b) / n and broadcast along the last dimension, where n is an element count selected at runtime by the FloorDiv/Select logic.

%fused_computation.3461.clone (param_0.15226: f32[3,35,1024], param_1.21991: f32[3,35,1024], param_2.18880: s32[], param_3.11384: s32[], param_4.4941: f32[3,35], param_5.2397: f32[3,35], param_6.2244: s32[], param_7.2283: f32[3,35,1024], param_8.2008: f32[3,35,1024], param_9.1351: f32[3,35,1024], param_10.643: s32[], param_11.421: s32[], param_12.202: s32[], param_13.379: f32[3,35], param_14.480: f32[3,35]) -> f32[3,35,1024] {
  %param_9.1351 = f32[3,35,1024]{2,1,0} parameter(9)
  %param_14.480 = f32[3,35]{1,0} parameter(14)
  %negate.1964 = f32[3,35]{1,0} negate(f32[3,35]{1,0} %param_14.480), metadata={op_type="Neg" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/SquaredDifference_grad/Neg"}
  %param_13.379 = f32[3,35]{1,0} parameter(13)
  %negate.1965 = f32[3,35]{1,0} negate(f32[3,35]{1,0} %param_13.379), metadata={op_type="Neg" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/sub_grad/Neg"}
  %add.9068 = f32[3,35]{1,0} add(f32[3,35]{1,0} %negate.1964, f32[3,35]{1,0} %negate.1965), metadata={op_type="AddN" op_name="training/gradients/AddN_4"}
  %constant.8355 = pred[] constant(false), metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_prepostprocess/layer_norm/Mean_1_grad/floordiv_1"}
  %constant.8354 = s32[] constant(840), metadata={op_type="Size" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_prepostprocess/layer_norm/Mean_1_grad/Prod_1"}
  %param_12.202 = s32[] parameter(12)
  %maximum.1386 = s32[] maximum(s32[] %constant.8354, s32[] %param_12.202), metadata={op_type="Maximum" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/Maximum_1"}
  %constant.8353 = s32[] constant(0), metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %compare.2227 = pred[] compare(s32[] %maximum.1386, s32[] %constant.8353), direction=LT, metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %compare.2228 = pred[] compare(pred[] %constant.8355, pred[] %compare.2227), direction=NE, metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %param_10.643 = s32[] parameter(10)
  %param_11.421 = s32[] parameter(11)
  %select.606 = s32[] select(pred[] %compare.2228, s32[] %param_10.643, s32[] %param_11.421), metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %convert.2718 = f32[] convert(s32[] %select.606), metadata={op_type="Cast" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/Cast"}
  %broadcast.18365 = f32[3,35]{1,0} broadcast(f32[] %convert.2718), dimensions={}
  %divide.3442 = f32[3,35]{1,0} divide(f32[3,35]{1,0} %add.9068, f32[3,35]{1,0} %broadcast.18365), metadata={op_type="RealDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/encdec_attention/layer_prepostprocess/layer_norm/Mean_grad/truediv"}
  %broadcast.18367 = f32[3,35,1024]{2,1,0} broadcast(f32[3,35]{1,0} %divide.3442), dimensions={0,1}
  %add.9069 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %param_9.1351, f32[3,35,1024]{2,1,0} %broadcast.18367), metadata={op_type="AddN" op_name="training/gradients/AddN_5"}
  %param_8.2008 = f32[3,35,1024]{2,1,0} parameter(8)
  %add.9070 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %add.9069, f32[3,35,1024]{2,1,0} %param_8.2008), metadata={op_type="AddN" op_name="training/gradients/AddN_5"}
  %param_7.2283 = f32[3,35,1024]{2,1,0} parameter(7)
  %add.9071 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %add.9070, f32[3,35,1024]{2,1,0} %param_7.2283), metadata={op_type="AddN" op_name="training/gradients/AddN_5"}
  %param_5.2397 = f32[3,35]{1,0} parameter(5)
  %negate.1966 = f32[3,35]{1,0} negate(f32[3,35]{1,0} %param_5.2397), metadata={op_type="Neg" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/SquaredDifference_grad/Neg"}
  %param_4.4941 = f32[3,35]{1,0} parameter(4)
  %negate.1967 = f32[3,35]{1,0} negate(f32[3,35]{1,0} %param_4.4941), metadata={op_type="Neg" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/sub_grad/Neg"}
  %add.9072 = f32[3,35]{1,0} add(f32[3,35]{1,0} %negate.1966, f32[3,35]{1,0} %negate.1967), metadata={op_type="AddN" op_name="training/gradients/AddN_7"}
  %param_6.2244 = s32[] parameter(6)
  %maximum.1385 = s32[] maximum(s32[] %constant.8354, s32[] %param_6.2244), metadata={op_type="Maximum" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/Maximum_1"}
  %compare.2225 = pred[] compare(s32[] %maximum.1385, s32[] %constant.8353), direction=LT, metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %compare.2226 = pred[] compare(pred[] %constant.8355, pred[] %compare.2225), direction=NE, metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %param_2.18880 = s32[] parameter(2)
  %param_3.11384 = s32[] parameter(3)
  %select.605 = s32[] select(pred[] %compare.2226, s32[] %param_2.18880, s32[] %param_3.11384), metadata={op_type="FloorDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/floordiv_1"}
  %convert.2717 = f32[] convert(s32[] %select.605), metadata={op_type="Cast" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/Cast"}
  %broadcast.18368 = f32[3,35]{1,0} broadcast(f32[] %convert.2717), dimensions={}
  %divide.3443 = f32[3,35]{1,0} divide(f32[3,35]{1,0} %add.9072, f32[3,35]{1,0} %broadcast.18368), metadata={op_type="RealDiv" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23/self_attention/layer_prepostprocess/layer_norm/Mean_grad/truediv"}
  %broadcast.18369 = f32[3,35,1024]{2,1,0} broadcast(f32[3,35]{1,0} %divide.3443), dimensions={0,1}
  %add.9073 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %add.9071, f32[3,35,1024]{2,1,0} %broadcast.18369), metadata={op_type="AddN" op_name="training/gradients/AddN_8"}
  %param_1.21991 = f32[3,35,1024]{2,1,0} parameter(1)
  %add.9074 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %add.9073, f32[3,35,1024]{2,1,0} %param_1.21991), metadata={op_type="AddN" op_name="training/gradients/AddN_8"}
  %param_0.15226 = f32[3,35,1024]{2,1,0} parameter(0)
  ROOT %add.9075 = f32[3,35,1024]{2,1,0} add(f32[3,35,1024]{2,1,0} %add.9074, f32[3,35,1024]{2,1,0} %param_0.15226), metadata={op_type="AddN" op_name="training/gradients/AddN_8"}
}

Generated Graph

[Graph rendering of %fused_computation.3461.clone]
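
The post doesn't spell out how the graph is built from the HLO, but the ops in the example suggest a natural reading: an elementwise add ties dimension d of its output to dimension d of each operand, while a broadcast such as %broadcast.18367 with dimensions={0,1} ties output dimensions 0 and 1 to operand dimensions 0 and 1 (the new dimension 2 has no counterpart in the f32[3,35] operand). A rough sketch of that derivation, with all types and names hypothetical:

// Hypothetical, simplified view of an HLO instruction, just enough to
// derive (tensor, dimension) edges for the propagation graph.
#[allow(dead_code)]
enum Op {
    // Elementwise over same-shaped operands, e.g. %add.9069.
    Elementwise { name: String, operands: Vec<String>, output: String, rank: usize },
    // e.g. %broadcast.18367: dimensions[k] is the output dim fed by operand dim k.
    Broadcast { name: String, operand: String, output: String, dimensions: Vec<usize> },
}

type Node = (String, usize); // (tensor name, dimension)

// Emit the node pairs an instruction ties together; each pair becomes a
// graph edge labeled with (instruction name, dimension choice).
fn edges(op: &Op) -> Vec<(Node, Node, (String, usize))> {
    let mut out = Vec::new();
    match op {
        Op::Elementwise { name, operands, output, rank } => {
            for d in 0..*rank {
                for operand in operands {
                    out.push(((output.clone(), d), (operand.clone(), d), (name.clone(), d)));
                }
            }
        }
        Op::Broadcast { name, operand, output, dimensions } => {
            // Only the mapped dims are tied; the broadcast (new) dims stay free.
            for (k, &od) in dimensions.iter().enumerate() {
                out.push(((output.clone(), od), (operand.clone(), k), (name.clone(), od)));
            }
        }
    }
    out
}

fn main() {
    // In HLO an instruction names its own output tensor.
    let b = Op::Broadcast {
        name: "broadcast.18367".into(),
        operand: "divide.3442".into(),
        output: "broadcast.18367".into(),
        dimensions: vec![0, 1],
    };
    for (x, y, label) in edges(&b) {
        println!("{x:?} -- {y:?} via {label:?}");
    }
}

For %broadcast.18367 this yields edges tying %divide.3442's dimensions 0 and 1 to the output's dimensions 0 and 1, so a tiling choice made on the f32[3,35] intermediate propagates up to the f32[3,35,1024] result.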

Propagation

Propagate::DFS

High Level Description

The propagation is a modified depth-first search over (parameter, dimension) nodes: starting from one node, it walks edges whose (instruction, choice) labels are consistent with the choices committed so far, accumulating for each parameter the set of dimensions it is tiled on. For example, starting at (param_0, dimension 0), an edge labeled (add.1, choice 0) can lead to (param_1, dimension 0) provided add.1 has not already been committed to a different choice.

BOTE Analysis

Roughly: each recursive call clones m, v_node and v_inst, so a single call costs on the order of the total size of those maps, and since the search branches over every compatible edge with independent copies, the worst case is exponential in the number of (parameter, dimension) nodes; the recursion depth is bounded by that node count.

Low Level Details

(A runnable Rust sketch of this procedure follows the pseudocode.)

  • Parameters:
    • n: a (param_name, dimension) pair representing the current node
    • m: HashMap<ParamName, HashSet<Dimension>>, the mapping accumulated so far from each parameter to the dimensions chosen for it
    • m_constraints: HashMap<ParamName, HashSet<Dimension>>, the dimensions each parameter is allowed to take
    • v_node: HashSet<(ParamName, Dimension)>, the nodes already visited on the current path
    • v_inst: HashMap<Instruction, Choice>, the choice already committed for each instruction
  • fn dfs(n, m, m_constraints, v_node, v_inst)
    • if n has already been visited, return with an error
    • v_node.insert(n)
    • m[n.0].insert(n.1)
    • if this m has already been seen (global memoization on map states), return None; otherwise mark m as seen
    • let suc_once = -1 (-1: no candidate edge tried yet, 0: every candidate failed, 1: at least one succeeded)
    • for e in the edges incident to n
      • skip e if any of its labels conflicts with a committed instruction choice:
        • valid = true
        • for (i, c) in e's (instruction, choice) labels
          • if i in v_inst and v_inst[i] ≠ c, then valid = false
        • if not valid, continue
      • let nw = the other endpoint of e
      • if nw in v_node, continue
      • if nw.0 in m_constraints and nw.1 not in m_constraints[nw.0], continue
      • let v_inst_clone = a copy of v_inst
      • for (i, c) in e's labels
        • if i not in v_inst, v_inst_clone.insert(i, c)
      • let new_map = dfs(nw, m.clone(), m_constraints, v_node.clone(), v_inst_clone)
      • if new_map is None
        • if suc_once == -1, suc_once = 0
        • continue
      • suc_once = 1
      • m = merge m with new_map
    • return None if suc_once == 0, else return m (a node whose candidate edges were never tried, suc_once == -1, still returns m)
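
To make the control flow concrete, here is the minimal runnable Rust sketch referenced above. It is an illustration under assumptions: the type aliases and the Edge/Graph representation are hypothetical, and the global memoization on map states from the pseudocode is omitted for brevity.

use std::collections::{HashMap, HashSet};

type ParamName = String;
type Dimension = usize;
type Instruction = String;
type Choice = usize;
type Node = (ParamName, Dimension);

// An edge between two (parameter, dimension) nodes, labeled with the
// (instruction, choice) commitments that taking the edge implies.
struct Edge {
    a: Node,
    b: Node,
    labels: Vec<(Instruction, Choice)>,
}

struct Graph {
    edges: Vec<Edge>,
}

impl Graph {
    // Edges incident to n, paired with the opposite endpoint.
    fn neighbors<'a>(&'a self, n: &'a Node) -> impl Iterator<Item = (&'a Edge, &'a Node)> {
        self.edges.iter().filter_map(move |e| {
            if &e.a == n {
                Some((e, &e.b))
            } else if &e.b == n {
                Some((e, &e.a))
            } else {
                None
            }
        })
    }
}

fn dfs(
    g: &Graph,
    n: &Node,
    mut m: HashMap<ParamName, HashSet<Dimension>>,
    m_constraints: &HashMap<ParamName, HashSet<Dimension>>,
    mut v_node: HashSet<Node>,
    v_inst: &HashMap<Instruction, Choice>,
) -> Option<HashMap<ParamName, HashSet<Dimension>>> {
    if v_node.contains(n) {
        return None; // callers skip visited nodes, so this is defensive
    }
    v_node.insert(n.clone());
    m.entry(n.0.clone()).or_default().insert(n.1);

    // -1: no candidate edge tried yet; 0: every candidate failed; 1: some succeeded.
    let mut suc_once = -1;

    for (e, nw) in g.neighbors(n) {
        // Skip edges whose labels conflict with a committed instruction choice.
        if e.labels.iter().any(|(i, c)| v_inst.get(i).map_or(false, |prev| prev != c)) {
            continue;
        }
        if v_node.contains(nw) {
            continue;
        }
        if let Some(allowed) = m_constraints.get(&nw.0) {
            if !allowed.contains(&nw.1) {
                continue;
            }
        }
        // Commit this edge's instruction choices on a copy of v_inst.
        let mut v_inst_clone = v_inst.clone();
        for (i, c) in &e.labels {
            v_inst_clone.entry(i.clone()).or_insert(*c);
        }
        match dfs(g, nw, m.clone(), m_constraints, v_node.clone(), &v_inst_clone) {
            None => {
                if suc_once == -1 {
                    suc_once = 0;
                }
            }
            Some(new_map) => {
                suc_once = 1;
                // Merge the recursive result into m.
                for (k, dims) in new_map {
                    m.entry(k).or_default().extend(dims);
                }
            }
        }
    }

    if suc_once == 0 { None } else { Some(m) }
}

fn main() {
    // Toy graph: tiling param_0 on dim 0 forces param_1 onto dim 0 as well,
    // via an edge that commits choice 0 on a hypothetical instruction add.1.
    let g = Graph {
        edges: vec![Edge {
            a: ("param_0".into(), 0),
            b: ("param_1".into(), 0),
            labels: vec![("add.1".into(), 0)],
        }],
    };
    let start: Node = ("param_0".into(), 0);
    let result = dfs(&g, &start, HashMap::new(), &HashMap::new(), HashSet::new(), &HashMap::new());
    println!("{result:?}"); // Some with both params mapped to {0}
}

Note the design choice the pseudocode encodes: m and v_node are cloned per branch so that failed branches leave no trace, while successful branches are merged back into m; a cheaper alternative would be undo-on-backtrack, at the cost of the simple merge semantics.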
