[PyTorch source code reading] torch.jit.trace: reading the C++ source behind the JIT interface

Preface

This article starts from the torch.jit.trace interface, which opens the door to PyTorch's JIT, and walks through how an ordinary nn.Module becomes a ScriptModule after tracing, together with the C++ classes involved along the way. Because a lot of code is involved, only the classes that are relatively important from a source-code point of view, or that help us understand the process, are covered here.

Before diving in, I wonder whether you have ever thought about a question that looks simple but may not be so easy to answer: why do programming languages need to distinguish between data types?

Roughly speaking, the reason is that a computer uses different computation circuits to implement different functions, and at the upper layer these different circuits correspond to different data types. So the most basic step in getting to know a language is to understand its data types.

torch::jit::Type

First and foremost come the types. Combing through the JIT code, everything ultimately rests on the torch::jit::Type class. This base class represents the different kinds of types; the 34 kinds listed below can currently be represented, and they cover the types found in IValue's Tag.

// torch/include/ATen/core/jit_type_base.h 
#define C10_FORALL_TYPES(_) \
  _(AnyType)                \
  _(EnumType)               \
  _(AnyEnumType)            \
  _(TensorType)             \
  _(StorageType)            \
  _(TupleType)              \
  _(ListType)               \
  _(DictType)               \
  _(NumberType)             \
  _(FloatType)              \
  _(ComplexType)      \
  _(FutureType)             \
  _(RRefType)               \
  _(IntType)                \
  _(NoneType)               \
  _(StringType)             \
  _(GeneratorType)          \
  _(QuantizerType)          \
  _(BoolType)               \
  _(OptionalType)           \
  _(VarType)                \
  _(DeviceObjType)          \
  _(StreamObjType)          \
  _(FunctionType)           \
  _(ClassType)              \
  _(PyObjectType)           \
  _(CapsuleType)            \
  _(InterfaceType)          \
  _(QSchemeType)            \
  _(LayoutType)             \
  _(ScalarTypeType)         \
  _(AnyListType)            \
  _(AnyTupleType)           \
  _(AnyClassType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
  C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

The Type class mainly wraps the TypeKind enumeration defined above. On top of it you can query the printable name, check whether one type is a subtype of another, check whether it is a module type, perform various casts, and obtain the types it contains as an array. Part of the source code follows:

struct TORCH_API Type : std::enable_shared_from_this<Type> {
 private:
  TypeKind kind_;
 protected:
  Type(TypeKind kind) : kind_(kind) {}
 public:
  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
  virtual std::string str() const = 0;
  template <typename T>
  std::shared_ptr<T> cast() {
    if (T::Kind == kind()) {
      return std::static_pointer_cast<T>(shared_from_this());
    }
    return nullptr;
  }
  
  virtual at::ArrayRef<TypePtr> containedTypes() const {
    return {};
  }
};
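
As a quick illustration (a minimal sketch, not taken from the article's code), the kind()/cast()/str() interface shown above can be exercised like this:

#include <ATen/core/jit_type.h>
#include <iostream>

// Minimal sketch of the Type interface: TensorType::get()/IntType::get() return
// singleton TypePtr instances, and cast<T>() only succeeds when the kinds match.
int main() {
  c10::TypePtr t = c10::TensorType::get();
  std::cout << t->str() << std::endl;                 // prints "Tensor"
  if (auto tensor_t = t->cast<c10::TensorType>()) {
    std::cout << "kind matches TensorType" << std::endl;
  }
  if (t->cast<c10::IntType>() == nullptr) {
    std::cout << "not an IntType" << std::endl;
  }
  return 0;
}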

ClassType

Under the path torch/include/ATen/core/jit_type.h, various subclasses of Type are implemented. The more important ones here are TensorType and ClassType. Since Tensor has been introduced before, I won't expand on TensorType here and will focus on ClassType.

An auxiliary class ClassAttribute is defined here. A class attribute mainly revolves around three things: its name, its kind (AttributeKind) and its type (TypePtr), which gives the following class definition:

struct TORCH_API ClassAttribute {
  public:
  ClassAttribute(AttributeKind kind,
                 TypePtr attributeType,
                 std::string attributeName) :
      kind_(kind),
      attributeType_(attributeType),
      attributeName_(std::move(attributeName)) {}

  AttributeKind getKind() const {
    return kind_;
  }

  TypePtr getType() const {
    return attributeType_;
  }

  const std::string& getName() const {
    return attributeName_;
  }

  private:
  AttributeKind kind_;
  TypePtr attributeType_;
  std::string attributeName_;
};

where:

enum class AttributeKind {
  BUFFER,
  PARAMETER,
  REGULAR_ATTRIBUTE
};

The class definition of ClassType is also quite long; the relevant parts are listed below:

struct TORCH_API ClassType : public NamedType {
  // Create a class type with name `name` and its methods stored in `cu`.
  static ClassTypePtr create(
      c10::optional<QualifiedName> qualifiedName,
      std::weak_ptr<CompilationUnit> cu,
      bool is_module = false,
      std::string doc_string = "",
      std::vector<std::string> unresolved_class_attributes = {});
  
  const std::vector<torch::jit::Function*>& methods() const;
  std::string str() const override;
  bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
  ...
  //—————————————————————— attribute operations ————————————————————————
  TypePtr findAttribute(const std::string& name) const;
  TypePtr getAttribute(const std::string& name) const;
  size_t numAttributes() const;
  size_t addOrCheckAttribute(
      const std::string& name,
      TypePtr ty,
      bool is_parameter = false,
      bool is_buffer = false);
  ...
 //—————————————————————— constant operations ————————————————————————
  size_t addConstant(const std::string& name, const IValue& value);
  IValue getConstant(const std::string& name) const;
  at::ArrayRef<IValue> constantValues() const;
  ...
 //—————————————————————— method / hook operations ————————————————————————
  void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
  void addForwardHook(torch::jit::Function* hook_ptr);
  const std::vector<torch::jit::Function*>& getForwardHooks() const;
  const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
  void addMethod(torch::jit::Function* method);
  torch::jit::Function* findMethod(const std::string& name) const;
  torch::jit::Function& getMethod(const std::string& name) const;
  torch::jit::Function* findHook(const std::string& name) const;
  private:
  std::vector<std::string> constantNames_;
  std::vector<IValue> constantValues_;
  
  std::vector<ClassAttribute> attributes_;
  std::vector<TypePtr> attributeTypes_;
  
  std::vector<torch::jit::Function*> methods_;
  std::vector<torch::jit::Function*> staticmethods_;
  
  std::vector<torch::jit::Function*> forward_hooks_;
  std::vector<torch::jit::Function*> forward_pre_hooks_;
  ...
};

The QualifiedName class represents a name of the form foo.bar.baz. The CompilationUnit class can be regarded as a list of named Functions: it stores the functions belonging to a class and provides interfaces for traversing and calling them. You can see that a ClassType really does look like a class. Besides the common methods inherited from Type, its member functions center on the class's attributes (ClassAttribute), its methods (Method, hook, pre-hook), and operations on constants. Finally, a CompilationUnit is required during construction.
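
To make this concrete, here is a rough sketch (hypothetical names, based only on the signatures listed above) of building a module-like ClassType inside a CompilationUnit and registering a tensor attribute on it:

#include <torch/script.h>

// Hypothetical sketch: create a module-like ClassType in a CompilationUnit and
// register a plain tensor attribute (not a parameter, not a buffer) on it.
void class_type_sketch() {
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  auto cls = c10::ClassType::create(
      c10::QualifiedName("__torch__.MyFakeModule"), cu, /*is_module=*/true);

  cls->addOrCheckAttribute("weight", c10::TensorType::get(),
                           /*is_parameter=*/false, /*is_buffer=*/false);

  TORCH_CHECK(cls->numAttributes() == 1);
  TORCH_CHECK(cls->findAttribute("weight") != nullptr);
}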

c10::ivalue::Object

The ClassType representing a class has now been defined; with a class type we can create objects. The first thing to introduce is c10::ivalue::Object, defined in ivalue_inl.h. The definition is not very long, and is posted below:

// torch/include/ATen/core/ivalue_inl.h
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 public:
  Object(StrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
    slots_.resize(numSlots);
  }

  static c10::intrusive_ptr<Object> create(
      StrongTypePtr type,
      size_t numSlots) {
    return c10::make_intrusive<Object>(std::move(type), numSlots);
  }
  // slot-related operations; slots_ is a std::vector<IValue>
  void setSlot(size_t slot, IValue v) {
    if (slot >= slots_.size()) {
      resizeObject(slot);
    }
    slots_[slot] = std::move(v);
  }

  const IValue& getSlot(size_t slot) const {
    return slots_[slot];
  }
  void unsafeRemoveSlot(size_t slot);
  // attributes are also IValues, so a module's attributes can be accessed by name
  IValue getAttr(const std::string& name) const;
  void setAttr(const std::string& name, IValue v);
  void unsafeRemoveAttr(const std::string& name);

  std::string name() const;

  const std::vector<IValue>& slots() const {
    return slots_;
  }
  std::shared_ptr<ClassType> type() const;
  // copy / deep-copy functions
  c10::intrusive_ptr<Object> copy() const;
  c10::intrusive_ptr<Object> deepcopy() const;

 private:
  void resizeObject(size_t slot);
  StrongTypePtr type_;
  std::vector<IValue> slots_;
};

// Some of the member functions above are implemented in aten/src/ATen/core/ivalue.cpp; interested readers can look there.

where StrongTypePtr is:

// torch/include/ATen/core/ivalue.h
struct TORCH_API StrongTypePtr {
  StrongTypePtr(
      std::shared_ptr<torch::jit::CompilationUnit> cu,
      std::shared_ptr<Type> type);

  std::shared_ptr<torch::jit::CompilationUnit> cu_;
  std::shared_ptr<Type> type_;
};

You can see that the ivalue::Object class mainly holds a slots_ member of type std::vector<IValue>, which stores the attributes of a module (or, more generally, of a ClassType). The bottom layer is slots_; the attribute (attr) interface is a thin layer wrapped over the slots. When the outer Module adds parameters or buffers, the lower layer ends up calling addOrCheckAttribute on the ClassType, which uses flag bits to indicate whether the attribute is a parameter, a buffer, or a regular attribute.
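
A minimal sketch of the slot-level interface (hypothetical values, based only on the declarations above):

#include <torch/script.h>

// Hypothetical sketch of the slot-level interface of c10::ivalue::Object:
// create an object with two slots and read/write them directly by index.
void object_slots_sketch(const c10::StrongTypePtr& type) {
  auto obj = c10::ivalue::Object::create(type, /*numSlots=*/2);

  obj->setSlot(0, c10::IValue(42));                   // an int attribute
  obj->setSlot(1, c10::IValue(torch::ones({2, 2}))); // a tensor attribute

  TORCH_CHECK(obj->getSlot(0).toInt() == 42);
  TORCH_CHECK(obj->slots().size() == 2);
}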

Then there is the torch::jit::Object class.

torch::jit::Object

// pytorch/pytorch/torch/csrc/jit/api/object.h

using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;

struct TORCH_API Object {
  Object() = default;
  Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
  Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Object(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  ObjectPtr _ivalue() const;

  c10::ClassTypePtr type() const {
    return _ivalue()->type();
  }
  
  void setattr(const std::string& name, c10::IValue v) {
    if (_ivalue()->type()->hasConstant(name)) {
      TORCH_CHECK(
          false,
          "Can't set constant '",
          name,
          "' which has value:",
          _ivalue()->type()->getConstant(name));
    } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
      const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
      TORCH_CHECK(
          v.type()->isSubtypeOf(expected),
          "Expected a value of type '",
          expected->repr_str(),
          "' for field '",
          name,
          "', but found '",
          v.type()->repr_str(),
          "'");
      _ivalue()->setSlot(*slot, std::move(v));
    } else {
      TORCH_CHECK(false, "Module has no attribute '", name, "'");
    }
  }
  
  
  c10::IValue attr(const std::string& name) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    std::stringstream err;
    err << _ivalue()->type()->repr_str() << " does not have a field with name '"
        << name.c_str() << "'";
    throw ObjectAttributeError(err.str());
  }

  c10::IValue attr(const std::string& name, c10::IValue or_else) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    return or_else;
  }

  bool hasattr(const std::string& name) const {
    return _ivalue()->type()->hasAttribute(name) ||
        _ivalue()->type()->hasConstant(name);
  }

  // Each object has its own method
  Method get_method(const std::string& name) const {
    if (auto method = find_method(name)) {
      return *method;
    }
    AT_ERROR("Method '", name, "' is not defined.");
  }

  const std::vector<Method> get_methods() const {
    return c10::fmap(type()->methods(), [&](Function* func) {
      return Method(_ivalue(), func);
    });
  }
  
  c10::optional<Method> find_method(const std::string& basename) const;
  
  template <typename... Types>
  IValue run_method(const std::string& method_name, Types&&... args) {
    return get_method(method_name)({IValue(std::forward<Types>(args))...});
  }
  
  void define(const std::string& src, const ResolverPtr& resolver = nullptr);

  size_t num_slots() const {
    return _ivalue()->slots().size();
  }
  // Copy
  Object copy() const;
  Object deepcopy() const;
  
  private:
  mutable ObjectPtr _ivalue_;
};

The torch::jit::Object class is an extension built on top of c10::ivalue::Object. For example, its copy-related functions simply call the corresponding c10::ivalue::Object implementations.

Looking at the constructors of torch::jit::Object, it is mainly constructed from a CompilationUnit pointer (cu) together with a ClassTypePtr. As the class definition above shows, torch::jit::Object is largely backed by c10::ivalue::Object, and its operations mostly go through the ClassType interface reached via _ivalue()->type(), performing the basic attribute and method operations.

The inheritance chain is ClassType -> NamedType -> Type. To sum up, the torch::jit::Object class performs attribute and Method operations through ivalue::Object plus ClassType, including deep and shallow copies of Object. It also holds a cu pointer of type CompilationUnit, supplied when the object is created, which is essentially a list of Functions.
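
As a rough, hypothetical sketch of this name-based interface (building on the ClassType sketch earlier; none of these names come from the article):

#include <torch/script.h>

// Hypothetical sketch: create a torch::jit::Object from a CompilationUnit and a
// ClassType that has a "weight" tensor attribute, then use setattr/attr/hasattr.
void jit_object_sketch() {
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  auto cls = c10::ClassType::create(
      c10::QualifiedName("__torch__.MyFakeModule"), cu, /*is_module=*/true);
  cls->addOrCheckAttribute("weight", c10::TensorType::get());

  torch::jit::Object obj(cu, cls);
  obj.setattr("weight", torch::ones({2, 2}));

  TORCH_CHECK(obj.hasattr("weight"));
  at::Tensor w = obj.attr("weight").toTensor();
  TORCH_CHECK(w.size(0) == 2 && w.size(1) == 2);
}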

torch::jit::Module

First of all, we should clarify the following concepts:

"Module" is an abstraction of the implementation of some functions or algorithms. A module class mainly has the following properties:

  • Buffers: tensors that do not record gradients; they are generally updated during forward propagation, such as the running mean and variance of the BatchNorm operator.
  • Parameters: tensors that record gradients, such as weights, which are usually updated during back propagation.
  • Other state: not necessarily tensors, but whatever state or configuration a Module implementation needs.

These correspond to AttributeKind in ClassType. Note that cloning a Module performs a deep copy. Module provides register_parameter and register_buffer to register the two different kinds of tensors.

Modules can be nested, that is, a module can contain submodules.

With this understanding, let's look at the torch::jit::Module class and its related operations.

// pytorch/pytorch/torch/csrc/jit/api/module.h
struct TORCH_API Module : public Object {
  // A group of constructors, largely mirroring the constructors of Object
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  // Main forward function
  IValue forward(std::vector<IValue> inputs) {
    return get_method("forward")(std::move(inputs));
  }
  
  // Registration of a module's important elements
  void register_buffer(const std::string& name, at::Tensor v) {
    bool is_param = false;
    bool is_buffer = true;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }
  
  // Recursively apply fn to this module and its submodules
  void apply(const std::function<void(Module&)>& fn);
  ...

  // Accessors for the module's actual contents
  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;

  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;

  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;

  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;
  
  // dump function
  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;
   ...
     
  // It also contains to(), save(), copy(), deepcopy(), clone(), clone_method() and other member functions, which are not covered here.
};
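
To make register_parameter / register_buffer concrete, here is a minimal sketch (the module name and tensors are made up, not from the article):

#include <torch/script.h>
#include <iostream>

// Hypothetical sketch: build an empty jit::Module, register a parameter and a
// buffer, then iterate over them. Both end up as attributes of the underlying
// ClassType / ivalue::Object, distinguished only by the is_parameter/is_buffer flags.
void module_registration_sketch() {
  torch::jit::Module m(c10::QualifiedName("__torch__.MyFakeModule"));

  m.register_parameter("weight", torch::randn({3, 3}), /*is_buffer=*/false);
  m.register_buffer("running_mean", torch::zeros({3}));

  for (const auto& p : m.named_parameters(/*recurse=*/false)) {
    std::cout << "parameter: " << p.name << std::endl;  // weight
  }
  for (const auto& b : m.named_buffers(/*recurse=*/false)) {
    std::cout << "buffer: " << b.name << std::endl;     // running_mean
  }
}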

Now, when we call torch.jit.trace and pass in an nn.Module, it is first converted into a ScriptModule on the Python side. Printed through the dump interface, the incoming jit::Module looks roughly like this:

module __torch__.TestArangeModel {
  parameters {
  }
  attributes {
    training = True
    _is_full_backward_hook = None
    conv1 = <__torch__.torch.nn.modules.conv.Conv2d object at 0x557569c9f300>
    conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d object at 0x557569cb4080>
  }
  methods {
  }
  submodules {
    module __torch__.torch.nn.modules.conv.Conv2d {
      parameters {
        weight = ...
      }
      attributes {
        weight = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
      }
      submodules {
      }
    } // The qualified name of the next submodule is: __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d
    module __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d {
      parameters {
        weight = ...  // Both parameters and buffers end up here; at the bottom layer they are represented the same way
        bias = ...
      }
      attributes {
        weight = ... // When registered, parameters and buffers are also added to the attributes, together with the other attrs
        bias = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {   // method
      }
      submodules { // The module itself is nested
      }
    }
  }
}

torch.jit.trace is essentially a process of converting a torch::jit::Module into a torch::jit::Graph. The focus is on a series of transformations applied to the Graph after the conversion; the intermediate state is tracked through a helper class called TracingState, and a converted torch::jit::Graph is finally returned.
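
For reference, a traced module saved from Python can be loaded on the C++ side and its forward graph printed like this (a small sketch; "traced.pt" is a hypothetical file name):

#include <torch/script.h>
#include <iostream>

int main() {
  // Hypothetical: "traced.pt" was produced by torch.jit.trace(...).save("traced.pt")
  // on the Python side. Load it and print the graph of its forward method.
  torch::jit::Module m = torch::jit::load("traced.pt");
  auto graph = m.get_method("forward").graph();
  std::cout << *graph << std::endl;
  return 0;
}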

Transformation process

The classes involved have been described above. With that background, we can now go through the general working process of the jit.trace interface on the C++ side.

As can be seen above, the Module has no methods at the beginning (presumably because it was converted from an nn.Module, whose own class definition carries no method content of this kind). Next, we will walk step by step through how the graph is built and what the final Module looks like.

  1. First, all attributes of the module object, together with its inputs and outputs, are recursively represented as nodes:
graph(%self : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %1 : bool = prim::TracedAttr[scope="__module.training"]()
  %2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
  %3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
  %4 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
  %5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
  %6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
  %7 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
  %8 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
  %9 : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
  %10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
  %11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
  return ()

This is what the Module above looks like after being converted to a Graph.

After the traced function has been run and its operations recorded, you get:

graph(%self : __torch__.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %1 : bool = prim::TracedAttr[scope="__module.training"]()
  %2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
  %3 : __torch__.torch.nn.modules.conv.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
  %weight.1 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
  %5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
  %6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
  %7 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
  %weight : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
  %bias : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
  %10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
  %11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
   = prim::TracedModuleForward[scope="__module.conv1"](), scope: __module.conv1
    block0():
      %13 : NoneType = prim::Constant(), scope: __module.conv1
      %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
      %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
      %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
      %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
      %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%x, %weight.1, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      -> ()
   = prim::TracedModuleForward[scope="__module.conv2"](), scope: __module.conv2
    block0():
      %33 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %34 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %35 : int[] = prim::ListConstruct(%33, %34), scope: __module.conv2
      %36 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %37 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %38 : int[] = prim::ListConstruct(%36, %37), scope: __module.conv2
      %39 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %40 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %41 : int[] = prim::ListConstruct(%39, %40), scope: __module.conv2
      %42 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %43 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %44 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %45 : int[] = prim::ListConstruct(%43, %44), scope: __module.conv2
      %46 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %47 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %48 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %49 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %50 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %51 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%input, %weight, %bias, %35, %38, %41, %42, %45, %46, %47, %48, %49, %50), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      -> ()
  return (%51)

The next step is to run various passes that further process the current graph. We now enter the relevant pass functions in torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp:

void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
  if (self) {
    ConvertTracedAttrReferences().run(graph);
  } else {
    for (Node* n : graph->nodes()) {
      TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
    }
  }
  MakeDefsDominateUses().run(graph->block());
  convertReturnsToTuples(graph->block());
  if (!self) {
    // We have no Module, so we're just going to inline everything.
    // This should give us a totally flat graph.
    inlineScopeBlocks(graph->block());
    // For TracedFork nodes
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    runCleanupPasses(graph);
  } else {
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    createMethodCalls(graph);
    runCleanupPasses(self);
    // `graph` isn't referenced in `self` yet, so we need to run
    // this separately
    runCleanupPasses(graph);
  }
}
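
These passes mostly walk the graph and rewrite nodes in place. As a flavour of what such a step looks like (a simplified sketch, not the actual pass), removing prim::TracedAttr nodes whose outputs are never used might be written as:

#include <torch/csrc/jit/ir/ir.h>
#include <vector>

using namespace torch::jit;

// Simplified sketch of a cleanup step: collect prim::TracedAttr nodes whose
// output has no uses, then destroy them (collect first so that we do not
// invalidate the iterator while walking the node list).
void removeUnusedTracedAttrs(const std::shared_ptr<Graph>& graph) {
  std::vector<Node*> to_delete;
  for (Node* n : graph->nodes()) {
    if (n->kind() == prim::TracedAttr && !n->output()->hasUses()) {
      to_delete.push_back(n);
    }
  }
  for (Node* n : to_delete) {
    n->destroy();
  }
}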

The main idea of this function is to sort out the relationships between the nodes in the graph and then handle the prim::TracedXXX nodes. Roughly, the following steps are performed:

  1. First, recursively find the prim::TracedModuleForward nodes, then find the prim::TracedAttr nodes with the same scope, obtain the corresponding submodule value from the input module and pass it into the prim::TracedModuleForward node as an input, and add the input information for block0.

  2. Then insert the necessary prim::GetAttr nodes to replace the prim::TracedAttr nodes that are still in use, and replace the corresponding outputs.

  3. Delete the prim::TracedAttr nodes that are no longer used.

  4. Give the prim::TracedModuleForward node an actual output (the left-hand side of the `=`), and add input/output names, types and other information.

  5. Pack multiple outputs into a tuple.

  6. Build a new subgraph from the block of prim::TracedModuleForward, and attach this subgraph to the node as its Subgraph attribute:

    That is, a fragment like this:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1"](%conv1), scope: __module.conv1
  block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
    %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
    %13 : NoneType = prim::Constant(), scope: __module.conv1
    %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
    %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
    %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
    %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
    %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    -> (%input)

becomes:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=<Graph>](%conv1, %x), scope: __module.conv1
  block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
    %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
    %13 : NoneType = prim::Constant(), scope: __module.conv1
    %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
    %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
    %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
    %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
    %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    -> (%input)

The block has now been stored on the prim::TracedModuleForward node as its Subgraph attribute, so the block itself can be deleted, leaving:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=<Graph>](%conv1, %x), scope: __module.conv1

Finally, the prim::TracedModuleForward node is converted into a prim::CallMethod node, and n->output()->replaceAllUsesWith(retval); is used to redirect all uses of the old node's output to the new one:

%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)

The resulting change to the overall graph is:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv1"](%conv1, %x), scope: __module.conv1
  %60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_1[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
  return (%60)

Then the unused nodes are deleted:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
  return (%60)

To sum up, the optimization passes mainly do the following:

  • Replace prim::TracedAttr with prim::GetAttr, and delete the redundant prim::TracedAttr nodes (those with no users).
  • Process the prim::TracedModuleForward nodes: first sort out the input and output information, then convert the block into the operator's Subgraph attribute, and finally replace the prim::TracedModuleForward node with a prim::CallMethod node.

When the prim::TracedModuleForward node is replaced with a prim::CallMethod node, the subgraph obtained from converting the block is also inserted into the methods of the corresponding submodule, so the incoming Module itself changes as well.
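
As a rough, hypothetical sketch of this replacement step (not the actual implementation in createMethodCalls), creating a prim::CallMethod node and swapping it in might look like this:

#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Hypothetical sketch: replace a single-output node `old_node` with a
// prim::CallMethod node that calls `forward` on `self_value` with `input_value`.
void replaceWithCallMethod(Graph& graph, Node* old_node,
                           Value* self_value, Value* input_value) {
  Node* call = graph.create(prim::CallMethod, {self_value, input_value});
  call->s_(attr::name, "forward");   // the method name attribute
  call->insertBefore(old_node);      // keep a valid topological order
  old_node->output()->replaceAllUsesWith(call->output());
  old_node->destroy();
}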

Repeating the above steps for each submodule, the graph finally generated by trace looks like this:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
  return (%62)

Finally, take a look at the converted module:

module __torch__.___torch_mangle_3.TestArangeModel {
  parameters {
  }
  attributes {
    training = True
    _is_full_backward_hook = None
    conv1 = <__torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d object at 0x559c691cd470>
    conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d object at 0x559c691cdcb0>
  }
  methods {
    method forward {
      graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
            %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
        %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
        %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
        %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
        %62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
        return (%62)
  
    }
  }
  submodules {
    module __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d {
      parameters {
        weight = ...
      }
      attributes {
        weight = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
        method forward {
          graph(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d,
                %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
            %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
            %2 : NoneType = prim::Constant(), scope: __module.conv1
            %3 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %4 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv1
            %6 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %7 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv1
            %9 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %10 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv1
            %12 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %13 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %14 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv1
            %16 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %17 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %18 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %19 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %20 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            return (%input)
      
        }
      }
      submodules {
      }
    }
    module __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d {
      parameters {
        weight = ...
        bias = ...
      }
      attributes {
        weight = ...
        bias = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
        method forward {
          graph(%self : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d,
                %22 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
            %bias : Tensor = prim::GetAttr[name="bias"](%self)
            %weight : Tensor = prim::GetAttr[name="weight"](%self)
            %3 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %4 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv2
            %6 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %7 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv2
            %9 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %10 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv2
            %12 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %13 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %14 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv2
            %16 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %17 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %18 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %19 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %20 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %21 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%22, %weight, %bias, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            return (%21)
        }
      }
      submodules {
      }
    }
  }
}

Summary

torch.jit.trace is a process of converting a module into a JIT graph, as described above:

  • Recursively process the attributes and submodules of the Module to generate the nodes of the Graph
  • Then transform, process and optimize these intermediate graph nodes, finally producing a torch::jit::Graph that we can print out

Much of the work consists of manipulating the nodes in the Graph. Readers who need more can dig further into the source code and read whatever additional content they need, following the thread introduced above.
