TensorFlow HLO 文本结构化

published @ rongyi.blog: https://rongyi.blog/2020-02-17-hlo-parsing

HLO Text的文件结构

来源: HLO的C++数据结构序列化之后的文本文件

文件全部由一个个函数构成,没有任何其他结构。
每个函数都由函数名、输入参数列表和类型、输出参数类型和函数体构成。
其中的函数体结构类似SSA(Static Single Assignment),每个变量都只会被赋值一次,并且名称唯一。做数据流图DFG的时候可以非常单纯的直接查找变量名找到这个变量被赋值和使用的地方。

每一条SSA指令结构大概如下:

%fusion.8228 = f32[4,32,48,32]{3,2,1,0} fusion(f32[192,1024]{1,0} %dot.2067, f32[] %arg217.0), kind=kLoop, calls=%fused_computation.4684.clone, metadata={op_type="Mul" op_name="transformer/parallel_0_5/transformer/transformer/body/encoder/layer_11/self_attention/multihead_attention/mul"}

大体就是 %var = $type $fn($params), {$metadata…}

结构化处理 attempt 1 (2020/02/11)

Observation

可以观察到每一条指令的数据流动都是沿等号从右往左,所以可以尝试直接使用Python对文本做字符串处理,大概思路就是

  1. 按照等号split每一条指令
  2. 等号左边处理%var1,作为左操作数
  3. 等号右边处理%var2, %var3, 作为右操作数
  4. 数据流关系就是左操作数依赖于右操作数

Implementation

代码实现如下

class block:
    """One parsed HLO function: signature, parameters, body and callees."""
    def __init__(self):
        # NOTE: the published post had smart quotes (“”) here, which is a
        # syntax error; restored to plain string literals.
        self.name = ""        # function name, e.g. "%fused_computation.19.clone"
        self.firstline = ""   # raw signature line (newline stripped)
        self.params = []      # parameter names parsed from the signature
        self.body = []        # raw instruction lines of the function body
        self.calls = []       # computations invoked via calls=/to_apply=

class node:
    """A data-flow-graph node: one SSA variable."""
    def __init__(self):
        # Smart quotes from the blog export replaced with real string literals.
        self.id = ""     # unique variable name, e.g. "%fusion.8228"
        self.label = ""  # display label (currently the full variable name)
class edge:
    """A directed DFG edge: `target` is computed from `source`."""
    def __init__(self):
        # Smart quotes from the blog export replaced with real string literals.
        self.source = ""  # operand (right-hand) variable name
        self.target = ""  # assigned (left-hand) variable name
class graph:
    """A flat data-flow graph: a list of nodes and a list of directed edges."""
    def __init__(self):
        self.nodes, self.edges = [], []

# Accumulates one `block` object per function parsed from the HLO text.
result = []

def process_body_line(s):
    """Extract names of computations invoked on an instruction line.

    Scans for every `calls=` and `to_apply=` occurrence and returns the list
    of values (text up to the next comma, or end of line).

    Fixes two defects of the original: a value with no trailing comma lost
    its last character (find() returning -1 used as a slice end), and the
    reversed-keyword replace() rewrote ALL occurrences at once, so only the
    first occurrence per keyword was ever captured.
    """
    ret = []
    for keyword in ("calls=", "to_apply="):
        search_from = 0
        while True:
            pos = s.find(keyword, search_from)
            if pos == -1:
                break
            value_start = pos + len(keyword)
            end = s.find(",", value_start)
            if end == -1:  # value is the last field on the line
                end = len(s)
            ret.append(s[value_start:end].rstrip("\n"))
            search_from = end
    return ret

def process_first_line_into_args(s):
    """Parse parameter names out of an HLO function signature line.

    Signature lines look like:
        %name (p0: f32[4,49], p1: f32[196,1024]) -> f32[1024] {
    Returns ["p0", "p1"]; an empty list for an empty parameter list.

    The published code had its string literals stripped (`s.find(->)`,
    `s.find(()`) and reused a stale index after slicing; rewritten to locate
    the parenthesised parameter list directly.
    """
    arrow = s.find("->")
    if arrow != -1:
        s = s[:arrow]          # drop the return-type part of the signature
    open_paren = s.find("(")
    close_paren = s.rfind(")")
    if open_paren == -1 or close_paren <= open_paren:
        return []
    inner = s[open_paren + 1:close_paren].strip()
    if not inner:
        return []
    # Each parameter is "name: type"; keep the text before the colon.
    return [p[:p.find(":")] for p in inner.split(", ")]

def process_first_line_into_name(s):
    """Return the function name from a signature line, minus any ENTRY marker."""
    return s[:s.find(" (")].replace("ENTRY ", "")

# Top-level scan over `lines` (the HLO text split into lines): lines at
# column 0 start a function; indented lines form its body; a line starting
# with "}" closes it. One `block` per function is appended to `result`.
# Restored the string literals the blog export stripped, and changed the
# no-op bare `exit` expression to an actual exit() call.
l = 0
while l < len(lines):
    line = lines[l]
    if len(line) < 2:
        # Blank / trivial line between functions.
        l = l + 1
        continue
    if line[:2] == "  ":
        # A body line seen at top level means the scanner is out of sync.
        print("Unhandled Situation, printing surround lines...")
        print(lines[l - 1], lines[l + 1])
        exit()
    if line[0] != " ":
        # Column-0 line: a function signature.
        f = block()
        f.firstline = line.replace("\n", "")
        f.params = process_first_line_into_args(f.firstline)
        f.name = process_first_line_into_name(f.firstline)
        l = l + 1
        line = lines[l]
        while line[0] != "}":
            # Collect callee names and raw body lines until the closing brace.
            f.calls = f.calls + process_body_line(line)
            f.body.append(line.replace("\n", ""))
            l = l + 1
            line = lines[l]
        result.append(f)
    l = l + 1

def parse_fn_line(s):
    """Return every %variable mentioned on an instruction line, in order.

    The first entry is the assigned (left-hand) variable; the rest are
    operands. A name runs from '%' to the next space, ')', ',', newline or
    end of string.

    Fixes the original's missing bounds check: a variable at the very end of
    the string raised IndexError, and a trailing '\\n' was folded into the
    variable name.
    """
    ret = []
    i = 0
    n = len(s)
    stop_chars = (' ', ')', ',', '\n')
    while i < n:
        if s[i] == '%':
            name = "%"
            i += 1
            while i < n and s[i] not in stop_chars:
                name += s[i]
                i += 1
            ret.append(name)
        i += 1
    return ret

def parse_fn_dfg(blk):
    """Build the data-flow graph of one function and return it as JSON text.

    Every %variable in `blk.body` becomes a node; for each instruction an
    edge is added from each operand to the assigned variable (the target
    depends on its sources).
    """
    per_line_vars = [parse_fn_line(body_line) for body_line in blk.body]
    g = graph()
    created = set()
    # One node per distinct variable name, in first-seen order.
    for names in per_line_vars:
        for name in names:
            if name in created:
                continue
            n = node()
            n.id = name
            n.label = name  # shorten_name(name) could be used here instead
            g.nodes.append(n)
            created.add(name)
    # names[0] is the assigned variable; the remainder are its operands.
    for names in per_line_vars:
        for operand in names[1:]:
            e = edge()
            e.source = operand
            e.target = names[0]
            g.edges.append(e)

    def dumper(obj):
        # Serialize via obj.toJSON() when available, otherwise fall back to
        # the instance dict (node/edge/graph are plain attribute bags).
        # Narrowed from a bare `except:` so real errors are not swallowed.
        try:
            return obj.toJSON()
        except AttributeError:
            return obj.__dict__

    return json.dumps(g, default=dumper, indent=2)

Limitation

马上问题就来了,HLO指令应该还是偏灵活,
光生成DFG碰到的bad case就有

  • 左操作数不止一个
  • 有可能没有右操作数

今天(13:23)在尝试做Variable Propagation的时候碰到了更多的问题,如

  • 需要读取variable的类型,类似f32[4,32,48,32]{3,2,1,0}这种信息在尝试切分矩阵的时候是必要的
  • 需要识别右边函数的名称,以及那些右操作数会被传入该函数,对不同函数切分矩阵的处理不同
  • 个别操作会在metadata里写重要信息…比如Slice会将slice的dimension放在后面

结构化处理 attempt 2 (2020/02/12)

Observation

尝试了一早上使用Python字符串匹配处理HLO文本,发现会触发各种Corner cases,比如

  • 右操作数可以直接是一个number
  • metadata可以是一个dict
  • metadata可以是一个[a:b] 的数组
  • metadata可以是一个字符串
  • 函数返回值的类型可以是一个数组:返回多个变量的函数

Implementation

下午改用另一种思路,使用词法和语法分析把这个SSA form当成 LL(k) 语法抽象生成语法树
词法比较简单,语法也能套LL1的大多数结构

考虑AST结构如下,内嵌EBNF和语法定义

// HLOLexer tokenizes HLO text via an EBNF grammar: comments, identifiers,
// quoted strings, %-prefixed variable names, numbers (including "inf"),
// whitespace, the "->" arrow, "=" and single punctuation characters.
// ("…" is the range operator of the ebnf grammar language.)
var HLOLexer = lexer.Must(ebnf.New(`
Comment = ("#" | "//") { "\u0000"…"\uffff"-"\n" } .
Ident = (alpha | "_") { "." | "_" | "-" | alpha | digit } .
String = "\"" {Ident | "/"} "\"" .
VarName = "%" Ident .
Number = { "-" } ("." | digit | "inf") {"." | digit} .
Whitespace = " " | "\t" | "\n" | "\r" .
Rightarrow = "->" .
Assign = "=" .
Punct = "!"…"/" | ":"…"@" | "["…"_" | "{"…"~" .
alpha = "a"…"z" | "A"…"Z" .
digit = "0"…"9" .
`))

// HLO is the parse root: an HLO text file is just a list of functions.
// The struct tags are the participle grammar driving the parser.
type HLO struct {
    Functions []*HLOFunction `@@*`
}

// HLOFunction is one function: name, typed parameter list, return type(s)
// (a single type or a parenthesised tuple) and a braced instruction body.
type HLOFunction struct {
    Name        string         `("ENTRY")? @VarName`
    Params      []*Param       `"(" [ @@ { "," @@ } ] ")"`
    ReturnTypes []*Type        `"->" ( "(" [ @@ { "," @@ } ] ")" | @@)`
    Body        []*Instruction `"{" @@ {@@} "}"`
}

// Instruction is one SSA assignment: %var = <call>, <meta>, <meta>...
type Instruction struct {
    VarName string        `("ROOT")? @VarName "="`
    Fn      *FunctionCall `@@`
    Meta    []*Meta       `{ "," @@ }`
}

// FunctionCall is the right-hand side of an assignment: return type(s),
// operation name and argument list.
type FunctionCall struct {
    ReturnTypes []*RichType  `(@@ | "(" @@ { "," @@ } ")" )`
    Name        string       `@Ident`
    Params      []*RichParam `"(" [ @@ { "," @@ } ] ")"`
}

// Meta is one trailing key=value attribute. The value can be a scalar,
// a {k=v ...} dict, a {n,n,...} number list or a {[a:b], ...} slice list;
// exactly one of the optional fields is populated per attribute.
type Meta struct {
    Key        string  `@Ident "="`
    Value      string  `(@Ident|@VarName|@Number)?`
    DictValue  []*Dict `("{" { @@ } "}")?`
    ListNums   []int   `("{" @Number {"," @Number } "}")?`
    ListSlices []Slice `("{" @@ {"," @@ } "}")?`
}

// Dict is one k=v pair inside a metadata dict, e.g. op_type="Mul".
type Dict struct {
    Key   string `@Ident "="`
    Value string `@String | @Ident`
}

// Slice is one [start:end] range as it appears in slice metadata.
type Slice struct {
    Start int `"[" @Number ":"`
    End   int `@Number "]"`
}

// Param is a signature parameter: its name and declared type.
type Param struct {
    Name string `@Ident ":"`
    Type *Type  `@@`
}

// Type is a plain tensor type: element type plus dimensions, e.g. f32[4,49].
type Type struct {
    DataType   string `@Ident`
    Dimensions []int  `"[" [ @Number { "," @Number } ] "]"`
}

// RichParam is a call argument: an optional type annotation followed by a
// variable name, a literal number or an identifier.
type RichParam struct {
    Type *RichType `(@@)?`
    Name string    `@VarName | @Number | @Ident`
}

// RichType is a type with an optional layout suffix, e.g. f32[4,49]{1,0};
// the layout numbers are appended to VarDim after the dimensions.
type RichType struct {
    VarType string `@Ident`
    VarDim  []int  `"[" [ @Number { "," @Number } ] "]" ("{" [ @Number { "," @Number } ] "}")?`
}

Result

解析事例函数如下

%fused_computation.19.clone (param_0.16672: f32[4,49,1024], param_1.23221: f32[196,1024]) -> f32[1024] {
  %param_1.23221 = f32[196,1024]{1,0} parameter(1)
  %reshape.13330 = f32[4,49,1024]{2,1,0} reshape(f32[196,1024]{1,0} %param_1.23221), metadata={op_type="Reshape" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/conv1/Tensordot/Reshape_grad/Reshape"}
  %param_0.16672 = f32[4,49,1024]{2,1,0} parameter(0)
  %multiply.14985 = f32[4,49,1024]{2,1,0} multiply(f32[4,49,1024]{2,1,0} %reshape.13330, f32[4,49,1024]{2,1,0} %param_0.16672), metadata={op_type="Mul" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Mul_1"}
  %constant.11228 = f32[] constant(0), metadata={op_type="RandomUniform" op_name="transformer/parallel_0_5/transformer/transformer/body/dropout/random_uniform/RandomUniform"}
  ROOT %reduce.1954 = f32[1024]{0} reduce(f32[4,49,1024]{2,1,0} %multiply.14985, f32[] %constant.11228), dimensions={0,1}, to_apply=%training_gradients_transformer_parallel_0_5_transformer_transformer_body_decoder_layer_23_1_ffn_layer_prepostprocess_layer_norm_mul_1_grad_Sum_1-reduction.48850, metadata={op_type="Sum" op_name="training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Sum_1"}
}

返回Tokens和AST结果为:

[%fused_computation.19.clone   ( param_0.16672 :   f32 [ 4 , 49 , 1024 ] ,   param_1.23221 :   f32 [ 196 , 1024 ] )   ->   f32 [ 1024 ]   {
     %param_1.23221   =   f32 [ 196 , 1024 ] { 1 , 0 }   parameter ( 1 )
     %reshape.13330   =   f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   reshape ( f32 [ 196 , 1024 ] { 1 , 0 }   %param_1.23221 ) ,   metadata = { op_type = "Reshape"   op_name = "training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/conv1/Tensordot/Reshape_grad/Reshape" }
     %param_0.16672   =   f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   parameter ( 0 )
     %multiply.14985   =   f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   multiply ( f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   %reshape.13330 ,   f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   %param_0.16672 ) ,   metadata = { op_type = "Mul"   op_name = "training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Mul_1" }
     %constant.11228   =   f32 [ ]   constant ( 0 ) ,   metadata = { op_type = "RandomUniform"   op_name = "transformer/parallel_0_5/transformer/transformer/body/dropout/random_uniform/RandomUniform" }
     ROOT   %reduce.1954   =   f32 [ 1024 ] { 0 }   reduce ( f32 [ 4 , 49 , 1024 ] { 2 , 1 , 0 }   %multiply.14985 ,   f32 [ ]   %constant.11228 ) ,   dimensions = { 0 , 1 } ,   to_apply = %training_gradients_transformer_parallel_0_5_transformer_transformer_body_decoder_layer_23_1_ffn_layer_prepostprocess_layer_norm_mul_1_grad_Sum_1-reduction.48850 ,   metadata = { op_type = "Sum"   op_name = "training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Sum_1" }
 } <EOF>]
&main.HLO{
  Functions: []*main.HLOFunction{
    &main.HLOFunction{
      Name: "%fused_computation.19.clone",
      Params: []*main.Param{
        &main.Param{
          Name: "param_0.16672",
          Type: &main.Type{
            DataType: "f32",
            Dimensions: []int{
              4,
              49,
              1024,
            },
          },
        },
        &main.Param{
          Name: "param_1.23221",
          Type: &main.Type{
            DataType: "f32",
            Dimensions: []int{
              196,
              1024,
            },
          },
        },
      },
      ReturnTypes: []*main.Type{
        &main.Type{
          DataType: "f32",
          Dimensions: []int{
            1024,
          },
        },
      },
      Body: []*main.Instruction{
        &main.Instruction{
          VarName: "%param_1.23221",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
                VarDim: []int{
                  196,
                  1024,
                  1,
                  0,
                },
              },
            },
            Name: "parameter",
            Params: []*main.RichParam{
              &main.RichParam{
                Name: "1",
              },
            },
          },
        },
        &main.Instruction{
          VarName: "%reshape.13330",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
                VarDim: []int{
                  4,
                  49,
                  1024,
                  2,
                  1,
                  0,
                },
              },
            },
            Name: "reshape",
            Params: []*main.RichParam{
              &main.RichParam{
                Type: &main.RichType{
                  VarType: "f32",
                  VarDim: []int{
                    196,
                    1024,
                    1,
                    0,
                  },
                },
                Name: "%param_1.23221",
              },
            },
          },
          Meta: []*main.Meta{
            &main.Meta{
              Key: "metadata",
              DictValue: []*main.Dict{
                &main.Dict{
                  Key: "op_type",
                  Value: "\"Reshape\"",
                },
                &main.Dict{
                  Key: "op_name",
                  Value: "\"training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/conv1/Tensordot/Reshape_grad/Reshape\"",
                },
              },
            },
          },
        },
        &main.Instruction{
          VarName: "%param_0.16672",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
                VarDim: []int{
                  4,
                  49,
                  1024,
                  2,
                  1,
                  0,
                },
              },
            },
            Name: "parameter",
            Params: []*main.RichParam{
              &main.RichParam{
                Name: "0",
              },
            },
          },
        },
        &main.Instruction{
          VarName: "%multiply.14985",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
                VarDim: []int{
                  4,
                  49,
                  1024,
                  2,
                  1,
                  0,
                },
              },
            },
            Name: "multiply",
            Params: []*main.RichParam{
              &main.RichParam{
                Type: &main.RichType{
                  VarType: "f32",
                  VarDim: []int{
                    4,
                    49,
                    1024,
                    2,
                    1,
                    0,
                  },
                },
                Name: "%reshape.13330",
              },
              &main.RichParam{
                Type: &main.RichType{
                  VarType: "f32",
                  VarDim: []int{
                    4,
                    49,
                    1024,
                    2,
                    1,
                    0,
                  },
                },
                Name: "%param_0.16672",
              },
            },
          },
          Meta: []*main.Meta{
            &main.Meta{
              Key: "metadata",
              DictValue: []*main.Dict{
                &main.Dict{
                  Key: "op_type",
                  Value: "\"Mul\"",
                },
                &main.Dict{
                  Key: "op_name",
                  Value: "\"training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Mul_1\"",
                },
              },
            },
          },
        },
        &main.Instruction{
          VarName: "%constant.11228",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
              },
            },
            Name: "constant",
            Params: []*main.RichParam{
              &main.RichParam{
                Name: "0",
              },
            },
          },
          Meta: []*main.Meta{
            &main.Meta{
              Key: "metadata",
              DictValue: []*main.Dict{
                &main.Dict{
                  Key: "op_type",
                  Value: "\"RandomUniform\"",
                },
                &main.Dict{
                  Key: "op_name",
                  Value: "\"transformer/parallel_0_5/transformer/transformer/body/dropout/random_uniform/RandomUniform\"",
                },
              },
            },
          },
        },
        &main.Instruction{
          VarName: "%reduce.1954",
          Fn: &main.FunctionCall{
            ReturnTypes: []*main.RichType{
              &main.RichType{
                VarType: "f32",
                VarDim: []int{
                  1024,
                  0,
                },
              },
            },
            Name: "reduce",
            Params: []*main.RichParam{
              &main.RichParam{
                Type: &main.RichType{
                  VarType: "f32",
                  VarDim: []int{
                    4,
                    49,
                    1024,
                    2,
                    1,
                    0,
                  },
                },
                Name: "%multiply.14985",
              },
              &main.RichParam{
                Type: &main.RichType{
                  VarType: "f32",
                },
                Name: "%constant.11228",
              },
            },
          },
          Meta: []*main.Meta{
            &main.Meta{
              Key: "dimensions",
              ListNums: []int{
                0,
                1,
              },
            },
            &main.Meta{
              Key: "to_apply",
              Value: "%training_gradients_transformer_parallel_0_5_transformer_transformer_body_decoder_layer_23_1_ffn_layer_prepostprocess_layer_norm_mul_1_grad_Sum_1-reduction.48850",
            },
            &main.Meta{
              Key: "metadata",
              DictValue: []*main.Dict{
                &main.Dict{
                  Key: "op_type",
                  Value: "\"Sum\"",
                },
                &main.Dict{
                  Key: "op_name",
                  Value: "\"training/gradients/transformer/parallel_0_5/transformer/transformer/body/decoder/layer_23_1/ffn/layer_prepostprocess/layer_norm/mul_1_grad/Sum_1\"",
                },
              },
            },
          },
        },
      },
    },
  },
}
{Functions:[0xc0000f03c0]}


You'll only receive email when 荣懿的草稿本 publishes a new post

More from 荣懿的草稿本