XLA: "Merging" HloModules

Published: 2024-04-26

The project has a requirement: operators matching specific patterns must be fused into a CustomCall and passed down to the backend.

For example, in the HLO computation graph below, the operators between the elided parts need to be fused into a CustomCall.

Analysis

Background of the graph merging

The frontend has an X86 backend simulator (hereafter "the simulator"); frontend graphs need to run on the simulator to verify the accuracy of the real chip backend.
Once the fusion requirement above lands, a graph containing the CustomCall obviously cannot run on the simulator directly, unless a dedicated expansion pass is added for this particular CustomCall. That means a lot of repetitive work, and any similar CustomCall requirement in the future would hit the same problem.

Before fusion

// Main graph
HloModule MainGraph, entry_computation_layout=...

%AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
  %x.15 = f32[] parameter(0)
  %y.16 = f32[] parameter(1)
  ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}

// entry_computation of the main graph
ENTRY %MainGraph ... {
  // omitted...
  %p0.2 = f32[10]{0} parameter(4)
  %p1.10 = f32[10]{0} parameter(7)
  %add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
  %constant.12 = f32[] constant(0)
  %constant.13 = f32[10,5]{1,0} constant(1)
  %reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
  %multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
  // omitted...
}

After fusion

After fusing into the CustomCall, the graph takes the following form:

// Main graph
HloModule SyncTensorsGraph.10, entry_computation_layout=...

// entry_computation of the main graph
ENTRY %SyncTensorsGraph.10 ... {
  // omitted...
  %custom=f32[10]{0} custom-call(%constant.89, %constant.92), custom_call_target="MyCustomCall", backend_config=""
  // omitted...
}

Approach

When the frontend fuses the operators into the CustomCall, the fused operators are abstracted into a standalone HloModule, which is passed along as a string in one of the CustomCall's attributes; the backend_config attribute is used here. For example, the operators above are abstracted into the HloModule below, where the HloModule's parameters are the CustomCall's operands and the HloModule's return value is the CustomCall's return value:

HloModule SyncTensorsGraph.42, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}

%AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
  %x.15 = f32[] parameter(0)
  %y.16 = f32[] parameter(1)
  ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}

ENTRY %SyncTensorsGraph.42 (p0.2: f32[10], p1.10: f32[10]) -> f32[10] {
  %p0.2 = f32[10]{0} parameter(0)
  %p1.10 = f32[10]{0} parameter(1)
  %add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
  %constant.12 = f32[] constant(0)
  %constant.13 = f32[10,5]{1,0} constant(1)
  %reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
  ROOT %multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
}
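
On the frontend side, attaching this module string to the CustomCall might look roughly like the sketch below. MakeFusedCustomCall is a hypothetical helper, and the exact CreateCustomCall overload and its opaque parameter may differ between XLA versions:

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"

namespace xla {

// Hypothetical helper: builds the CustomCall and carries the extracted
// sub-module as the opaque/backend_config string.
std::unique_ptr<HloInstruction> MakeFusedCustomCall(
    const Shape& result_shape, absl::Span<HloInstruction* const> operands,
    const HloModule& fused_module) {
  return HloInstruction::CreateCustomCall(
      result_shape, operands,
      /*custom_call_target=*/"MyCustomCall",
      /*opaque=*/fused_module.ToString());
}

}  // namespace xla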

The graph before merging

Embedded in the main graph, it looks as follows:

HloModule Test

ENTRY Test  {
  %p0.1 = f32[10]{0} parameter(0)
  %constant.89 = f32[10]{0} constant(1)
  %constant.92 = f32[10]{0} constant(2)
  %custom = f32[10]{0} custom-call(%constant.89, %constant.92), custom_call_target="MyCustomCall", backend_config="
    HloModule SyncTensorsGraph.42, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}

    %AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
      %x.15 = f32[] parameter(0)
      %y.16 = f32[] parameter(1)
      ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
    }

    ENTRY %SyncTensorsGraph.42 (p0.2: f32[10], p1.10: f32[10]) -> f32[10] {
      %p0.2 = f32[10]{0} parameter(0)
      %p1.10 = f32[10]{0} parameter(1)
      %add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
      %constant.12 = f32[] constant(0)
      %constant.13 = f32[10,5]{1,0} constant(1)
      %reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
      ROOT %multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
    }
  "
  ROOT add = f32[10]{0} add(%custom, %constant.92)
}

The graph after merging

Note the following:

  • The CustomCall in the original graph has been expanded back into the combination of HloInstructions it represents
  • The non-entry computations of the original "sub-graph" have been cloned into the main module and are invoked as usual

HloModule Test, entry_computation_layout={(f32[10]{0})->f32[10]{0}}

%AddComputation.14.clone (x.15: f32[], y.16: f32[]) -> f32[] {
  %x.15 = f32[] parameter(0)
  %y.16 = f32[] parameter(1)
  ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}

ENTRY %Test (p0.1: f32[10]) -> f32[10] {
  %p0.1 = f32[10]{0} parameter(0)
  %constant.89 = f32[10]{0} constant({1, 0, 0, 0, 0, 0, 0, 0, 0, 0})
  %constant.92 = f32[10]{0} constant({2, 0, 0, 0, 0, 0, 0, 0, 0, 0})
  %add.1 = f32[10]{0} add(f32[10]{0} %constant.89, f32[10]{0} %constant.92)
  %constant.1 = f32[10,5]{1,0} constant({...})
  %constant = f32[] constant(0)
  %reduce = f32[10]{0} reduce(f32[10,5]{1,0} %constant.1, f32[] %constant), dimensions={1}, to_apply=%AddComputation.14.clone
  %multiply = f32[10]{0} multiply(f32[10]{0} %add.1, f32[10]{0} %reduce)
  ROOT %add = f32[10]{0} add(f32[10]{0} %multiply, f32[10]{0} %constant.92)
}

Implementation

  1. The overall algorithm is recursive, inspired by the classic "copy a binary tree" algorithm; here it is really a graph that gets copied.
  2. The sub-graph's root instruction and non-root instructions have to be handled differently: the root instruction can be moved into the main graph directly as a unique_ptr, so std::variant is used: the root case returns std::unique_ptr<HloInstruction>, and the non-root case returns HloInstruction*.
  3. While globally replacing HloInstruction* with the HloInstructionPtr alias, a problem showed up around CopyGraph's second parameter; it is exactly the point made in the "Some Remarks About Programming Style" section of "C++ Templates - The Complete Guide, 2nd Edition": where to put const relative to the type, and why to write it that way (see the snippet after this list).
  4. The different kinds of HloInstruction are largely covered: ordinary binary operators, Constant, Parameter, and Reduce (which calls a sub-computation).
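
A minimal illustration of the const-placement pitfall from point 3 (the alias mirrors the one defined in the pass below; Demo is just a hypothetical signature):

class HloInstruction;  // stand-in forward declaration for the illustration
using HloInstructionPtr = HloInstruction*;

// With a pointer alias, a leading "const" applies to the alias as a whole:
// "const HloInstructionPtr" means HloInstruction* const (const pointer, mutable
// pointee), not "const HloInstruction*". Spelling it "HloInstruction const*",
// as CopyGraph's second parameter does, keeps the intent unambiguous.
void Demo(const HloInstruction* a,     // pointer to const HloInstruction
          HloInstruction const* b,     // same type as a, "east const" spelling
          const HloInstructionPtr c);  // HloInstruction* const: pointee not const

The complete pass implementation follows.
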
#include "tensorflow/compiler/plugin/ipu/driver/passes/ipu_custom_call_expander_pass.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
// Use HloParser to parse the backend_config string back into an HloModule.
// The bazel BUILD deps need the corresponding dependency: "//tensorflow/compiler/xla/service:hlo_parser"
#include "tensorflow/compiler/xla/service/hlo_parser.h"

namespace xla {
namespace {

using HloInstructionUniquePtr = std::unique_ptr<HloInstruction>;
using HloInstructionPtr = HloInstruction*;

std::variant<HloInstructionPtr, HloInstructionUniquePtr>
CopyGraph(HloComputation* parent,
          HloInstruction const* root,
          const HloInstruction::InstructionVector& cc_operands,
          std::vector<HloComputation*> *const non_entry_comps,
          bool is_root_instr = false) {

  auto opcode = root->opcode();
  VLOG(2) << "opcode " << opcode;

  HloInstructionUniquePtr cloned_instr;
  switch (opcode) {
  case HloOpcode::kMultiply:
  case HloOpcode::kAdd:
    {
      cloned_instr = HloInstruction::CreateBinary(
          root->shape(), opcode,
          std::get<HloInstructionPtr>(CopyGraph(parent, root->operand(0), cc_operands, non_entry_comps)),
          std::get<HloInstructionPtr>(CopyGraph(parent, root->operand(1), cc_operands, non_entry_comps)));
      // HloInstruction::HloInstruction has protected access, can not be constructed here
      // new_instr = new HloInstruction(opcode, root->shape());
      // new_instr->AppendOperand(const_cast<HloInstructionPtr>(root->operand(0)));
      // new_instr->AppendOperand(const_cast<HloInstructionPtr>(root->operand(1)));
    }
    break;
  // The graph's parameters are exactly the CustomCall's operands, in one-to-one positional order.
  case HloOpcode::kParameter:
    return cc_operands[root->parameter_number()];
  case HloOpcode::kConstant:
    {
      // The parsed sub-module is temporary, so its literal can be moved out
      // rather than copied.
      auto& literal = const_cast<Literal&>(root->literal());
      cloned_instr = HloInstruction::CreateConstant(std::move(literal));
      break;
    }
  case HloOpcode::kReduce:
    {

      // tensorflow/compiler/xla/service/hlo_instruction.cc:1569#HloInstruction::CreateReduce
      // TODO: handle the case of multiple operands/init_values.
      auto* root_nc = const_cast<HloInstructionPtr>(root);
      auto* reduce = dynamic_cast<HloReduceInstruction*>(root_nc);
      HloComputation* comp = nullptr;
      VLOG(1) << non_entry_comps->size();
      for (auto* cp : *non_entry_comps) {
        // Pick out the cloned HloComputation with the matching name.
        if (cp->name() == reduce->to_apply()->name() + ".clone") {
          VLOG(3) << cp->name() << " vs " << reduce->to_apply()->name();
          comp = cp;
        }
      }
      CHECK(comp != nullptr) << "cloned computation not found for "
                             << reduce->to_apply()->name();
      VLOG(1) << comp->ToString();
      cloned_instr = HloInstruction::CreateReduce(
          root->shape(),
          std::get<HloInstructionPtr>(CopyGraph(parent, root->operand(0), cc_operands, non_entry_comps)),
          std::get<HloInstructionPtr>(CopyGraph(parent, root->operand(1), cc_operands, non_entry_comps)),
          // operand(1) and init_values()[0] are both fine here:
          // std::get<HloInstructionPtr>(CopyGraph(parent, reduce->init_values()[0], cc_operands, non_entry_comps)),
          reduce->dimensions(),
          comp);
      break;
    }
  default:
    VLOG(1) << "unsupported opcode: " << opcode;
    break;
  }
  CHECK(cloned_instr) << "cloned_instr is not assigned";
  VLOG(1) << cloned_instr->parent() << ", " << cloned_instr->ToString();
  if (!is_root_instr) {
    // Fetch the raw pointer first; the unique_ptr is empty once moved.
    auto raw = cloned_instr.get();
    // Non-root instructions are added to the enclosing HloComputation here.
    parent->AddInstruction(std::move(cloned_instr));
    return raw;
  }
  return std::move(cloned_instr);
}

}  // namespace

namespace ipu {

StatusOr<bool> IpuCustomCallExpanderPass::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "IpuCustomCallExpanderPass::Run(), before:\n" + module->ToString());
  bool changed = false;

  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      if (instruction->IsDead()) {
        continue;
      }

      if(HloOpcode::kParameter == instruction->opcode()){
        VLOG(1) << "instruction->parameter_number() " << instruction->parameter_number();
      }
      if (HloOpcode::kCustomCall == instruction->opcode()) {
        HloCustomCallInstruction* cc = xla::Cast<xla::HloCustomCallInstruction>(instruction);

        // Parse the custom-call's backend_config string back into an HloModule.
        auto sm = xla::ParseAndReturnUnverifiedModule(cc->opaque());
        TF_RETURN_IF_ERROR(sm.status());
        auto cm = std::move(sm.value());

        HloComputation* target_comp = nullptr;
        // Collected per CustomCall.
        std::vector<HloComputation*> non_entry_comps;
        // Collect all non-entry computations and clone them into the outer module.
        for (auto* comp : cm->MakeNonfusionComputations()) {
          if (!comp->IsEntryComputation()) {
            auto cp = comp->Clone();
            non_entry_comps.push_back(cp.get());
            module->AddEmbeddedComputation(std::move(cp));
          } else {
            target_comp = comp;
          }
        }

        // The sub-module's entry computation has the same parameters as the
        // custom-call; the custom-call's operands are substituted for them.
        auto new_root = std::get<HloInstructionUniquePtr>(
            CopyGraph(computation, target_comp->root_instruction(), cc->operands(),
                      &non_entry_comps, /*is_root_instr=*/true));
        TF_RETURN_IF_ERROR(
            computation->ReplaceWithNewInstruction(instruction, std::move(new_root)));
        changed = true;
      }
    }
  }
  XLA_VLOG_LINES(2, "IpuCustomCallExpanderPass::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace ipu
}  // namespace xla
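
For completeness, a minimal sketch of how such a pass could be wired into a pipeline; the pipeline name is illustrative, and the exact HloPassPipeline API varies between XLA versions:

#include "tensorflow/compiler/plugin/ipu/driver/passes/ipu_custom_call_expander_pass.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

// Illustrative wiring only; where this runs depends on the backend's compile flow.
xla::Status ExpandIpuCustomCalls(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("ipu-custom-call-expander");
  pipeline.AddPass<xla::ipu::IpuCustomCallExpanderPass>();
  return pipeline.Run(module).status();
}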


todo

Per-operator support: only a few typical operators are handled so far.