Relationship between Object, ObjectPtr, and ObjectRef in TVM C++ code

The Object class of TVM is the base class of many classes. For a detailed analysis, please refer to the following articles:

In-depth understanding of TVM: Object family - Zhihu

In-depth understanding of TVM: Object family (II) - Zhihu

TVM source code reading: the cornerstone of all things - Object class (1) - Zhihu

TVM source code reading: the cornerstone of all things - Object(2) - Zhihu

When reading TVM C++ code, type conversions between the many classes derived from Object have to be traced back to Object, ObjectPtr, and ObjectRef, so here we focus on the relationship between these three. Keeping only the code that shows how they are related:

class TVM_DLL Object {
public:
 
    ...
protected:
 
    ...
private:
    ...
    // ObjectPtr is a class template, so it is befriended as a template
    template <typename>
    friend class ObjectPtr;
    ...
};


template <typename T>
class ObjectPtr {
public:
    ...

private:
    Object* data_{nullptr};
    ...
    friend class Object;
    friend class ObjectRef;
    ...
};


class ObjectRef {
public:
    ...  

protected:

    ObjectPtr<Object> data_;
  
};

As the code above shows, the data member data_ of ObjectPtr is a raw Object pointer, while the data member data_ of ObjectRef is an ObjectPtr<Object> instance. In other words, an ObjectRef wraps an ObjectPtr, which in turn wraps a pointer to the Object.
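To make the layering concrete, here is a minimal sketch under a heavily simplified model: the classes below only borrow TVM's names and are not TVM's actual implementation (TVM's version also handles, for instance, atomic counting, custom deleters, moves and assignment). The reference count lives inside the Object itself (intrusive counting), ObjectPtr holds the raw pointer and adjusts the count, and ObjectRef holds an ObjectPtr<Object> by value.

```cpp
#include <cassert>

// Simplified stand-ins that borrow TVM's names; NOT TVM's real implementation.
template <typename T>
class ObjectPtr;  // forward declaration so Object can befriend the template

class Object {
 public:
  virtual ~Object() = default;
  int use_count() const { return ref_counter_; }

 private:
  int ref_counter_{0};  // the count is stored inside the object (intrusive)
  template <typename>
  friend class ObjectPtr;  // only ObjectPtr touches the counter
};

template <typename T>
class ObjectPtr {
 public:
  explicit ObjectPtr(Object* data) : data_(data) {
    if (data_ != nullptr) ++data_->ref_counter_;  // one more owner
  }
  ObjectPtr(const ObjectPtr& other) : ObjectPtr(other.data_) {}
  ObjectPtr& operator=(const ObjectPtr&) = delete;  // omitted for brevity
  ~ObjectPtr() {
    // the last owner deletes the object
    if (data_ != nullptr && --data_->ref_counter_ == 0) delete data_;
  }
  T* get() const { return static_cast<T*>(data_); }

 private:
  Object* data_{nullptr};  // layer 1: raw pointer to the Object
};

class ObjectRef {
 public:
  explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
  const Object* get() const { return data_.get(); }

 protected:
  ObjectPtr<Object> data_;  // layer 2: an ObjectPtr<Object> held by value
};
```

Copying an ObjectRef copies its inner ObjectPtr, so both handles point at the same Object and its count rises to 2. When the last handle is destroyed, the count drops to 0 and the Object is deleted.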

ObjectPtr operates on the corresponding Object instance through its private member data_. But how do we get from an object back to an ObjectRef that refers to it? For this, the definition of ObjectPtr declares the function template GetRef as a friend:

template <typename T>
class ObjectPtr {
public:
    ...

private:
    Object* data_{nullptr};
    ...
    friend class Object;
    friend class ObjectRef;
    ...
    template <typename RelayRefType, typename ObjType>
    friend RelayRefType GetRef(const ObjType* ptr);
    ...
};

template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
    static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                "Can only cast to the ref of same container type");
    if (!RefType::_type_is_nullable) {
      ICHECK(ptr != nullptr);
    }
    return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

Ignoring the checks at the start of GetRef, we look only at the final return statement.

const_cast<Object*>(static_cast<const Object*>(ptr)) first converts the const ObjType* pointer to const Object* (an upcast) and then removes the const qualifier. In TVM's Object family, by convention all classes whose names end with Node derive (directly or indirectly) from Object, and classes whose names do not end with Node derive from ObjectRef. Therefore, a pointer to any *Node class can be converted to Object*.
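The two-step cast can be tried on its own. In this small sketch FooNode is a hypothetical class standing in for any *Node type, and ToMutableObject is a hypothetical helper that isolates the expression from GetRef's return statement. static_cast does the upcast because it cannot remove const, and const_cast is applied afterwards because changing qualifiers is the only thing it may do.

```cpp
#include <cassert>

// Plain hierarchy mirroring the "XxxNode derives from Object" convention.
// FooNode is a hypothetical name used only for this illustration.
struct Object {
  virtual ~Object() = default;
};
struct FooNode : Object {
  int value = 0;
};

// Same expression as the return statement of GetRef:
//   1. static_cast<const Object*> performs the upcast and keeps const,
//      since static_cast is not allowed to cast away qualifiers;
//   2. const_cast<Object*> then removes const and nothing else.
inline Object* ToMutableObject(const FooNode* ptr) {
  return const_cast<Object*>(static_cast<const Object*>(ptr));
}
```

Removing const this way is well-defined as long as the object itself was not created as const, which holds for heap-allocated nodes.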

ObjectPtr<Object>(xxx) calls the constructor of ObjectPtr to create an ObjectPtr instance:

  explicit ObjectPtr(Object* data) : data_(data) {
    if (data != nullptr) {
      data_->IncRef();
    }
  }

IncRef increments the reference count of the Object; we will not discuss it in detail here.

Next, RefType(xxx) calls the constructor of RefType with the ObjectPtr as its argument, producing an instance of RefType. If RefType is ObjectRef itself, the corresponding constructor is:

explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}

This produces an ObjectRef. If RefType is instead a subclass of ObjectRef whose node type corresponds to ObjType (for example IRModule and IRModuleNode), GetRef produces that reference type in the same way, so the matching Ref type can be obtained from a pointer to its Node type.
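Putting the pieces together, here is a minimal sketch of GetRef under the same simplifying assumptions: Object, ObjectPtr and ObjectRef are reduced stand-ins rather than TVM's code, FooNode/Foo are a hypothetical Node/Ref pair playing the role of IRModuleNode/IRModule, and assert replaces ICHECK. The GetRef body mirrors the function quoted above. It can start from a bare node pointer because the reference count is stored in the node, so a fresh ObjectPtr built from that pointer simply joins the existing owners.

```cpp
#include <cassert>
#include <type_traits>

// Simplified stand-ins borrowing TVM's names; NOT TVM's real implementation.
template <typename T>
class ObjectPtr;

class Object {
 public:
  virtual ~Object() = default;
  int use_count() const { return ref_counter_; }

 private:
  int ref_counter_{0};  // intrusive reference count
  template <typename>
  friend class ObjectPtr;
};

template <typename T>
class ObjectPtr {
 public:
  explicit ObjectPtr(Object* data) : data_(data) {
    if (data_ != nullptr) ++data_->ref_counter_;  // plays the role of IncRef()
  }
  ObjectPtr(const ObjectPtr& other) : ObjectPtr(other.data_) {}
  ObjectPtr& operator=(const ObjectPtr&) = delete;  // omitted for brevity
  ~ObjectPtr() {
    if (data_ != nullptr && --data_->ref_counter_ == 0) delete data_;  // DecRef()
  }
  T* get() const { return static_cast<T*>(data_); }

 private:
  Object* data_{nullptr};
};

class ObjectRef {
 public:
  using ContainerType = Object;
  static constexpr bool _type_is_nullable = true;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
  const Object* get() const { return data_.get(); }

 protected:
  ObjectPtr<Object> data_;
};

// Hypothetical Node/Ref pair, analogous to IRModuleNode/IRModule.
class FooNode : public Object {
 public:
  int value = 0;
};

class Foo : public ObjectRef {
 public:
  using ContainerType = FooNode;
  static constexpr bool _type_is_nullable = false;
  explicit Foo(ObjectPtr<Object> data) : ObjectRef(data) {}
  const FooNode* operator->() const { return static_cast<const FooNode*>(get()); }
};

// Same shape as the GetRef quoted above, with ICHECK replaced by assert.
template <typename RefType, typename ObjType>
RefType GetRef(const ObjType* ptr) {
  static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                "Can only cast to the ref of same container type");
  if (!RefType::_type_is_nullable) {
    assert(ptr != nullptr);
  }
  return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
```

Calling GetRef<Foo> with a pointer whose type does not derive from FooNode would be rejected at compile time by the static_assert, while GetRef<ObjectRef> accepts any *Node pointer because every node derives from Object.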

Such conversions appear throughout the code, for example in IRModule::FromExprInContext:

if (auto* func_node = expr.as<BaseFuncNode>()) {
    func = GetRef<BaseFunc>(func_node);
    ...
}

Here expr.as<BaseFuncNode>() converts expr to a const BaseFuncNode* (or returns nullptr if expr does not hold a BaseFuncNode). All types ending with Node derive from Object, so GetRef can then turn that pointer into a BaseFunc reference (types not ending with Node derive from ObjectRef).

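In this code, as is a member function template of ObjectRef that performs a checked conversion along an inheritance chain: it returns the underlying object as a pointer to the requested node type, or nullptr if the object is not of that type. The definition of as: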

template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
  if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
    return static_cast<ObjectType*>(data_.get());
  } else {
    return nullptr;
  }
}

Conceptually, substituting BaseFuncNode for the template parameter gives:

inline const BaseFuncNode* ObjectRef::as() const {
  if (data_ != nullptr && data_->IsInstance<BaseFuncNode>()) {
    return static_cast<BaseFuncNode*>(data_.get());
  } else {
    return nullptr;
  }
}

In IRModule::FromExprInContext, the parameter expr has type RelayExpr, and what we pass in from Python is a Function. Function inherits from BaseFunc, BaseFunc inherits from RelayExpr, and the chain ultimately derives from ObjectRef.

The IsInstance method uses the type index of a type to determine whether an object is an instance of that type or of one of its subclasses. Here the object behind expr is a FunctionNode, which derives from BaseFuncNode, so IsInstance<BaseFuncNode>() returns true.
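The idea behind IsInstance can be sketched with a tiny type-index table. This is a simplified model, not TVM's code: the index values and the kParentOf table are made up, and the check simply walks up the parent chain. TVM's implementation is more elaborate (it uses a global type registry and has fast paths for final types and reserved index ranges), but it answers the same question. As is a hypothetical free-function counterpart of ObjectRef::as<T>().

```cpp
#include <cassert>

// Hypothetical type indices for a small slice of the hierarchy (not TVM's values).
enum TypeIndex : int { kObject = 0, kRelayExprNode = 1, kBaseFuncNode = 2, kFunctionNode = 3 };
// kParentOf[i] is the type index of i's parent class; -1 marks the root.
constexpr int kParentOf[] = {-1, kObject, kRelayExprNode, kBaseFuncNode};

class Object {
 public:
  static constexpr int RuntimeTypeIndex() { return kObject; }
  virtual ~Object() = default;

  // True if the dynamic type of *this is T or a subclass of T:
  // start at our own index and walk up the parent chain.
  template <typename T>
  bool IsInstance() const {
    for (int idx = type_index_; idx != -1; idx = kParentOf[idx]) {
      if (idx == T::RuntimeTypeIndex()) return true;
    }
    return false;
  }

 protected:
  int type_index_{kObject};  // overwritten by each constructor down the chain
};

struct RelayExprNode : Object {
  static constexpr int RuntimeTypeIndex() { return kRelayExprNode; }
  RelayExprNode() { type_index_ = kRelayExprNode; }
};
struct BaseFuncNode : RelayExprNode {
  static constexpr int RuntimeTypeIndex() { return kBaseFuncNode; }
  BaseFuncNode() { type_index_ = kBaseFuncNode; }
};
struct FunctionNode : BaseFuncNode {
  static constexpr int RuntimeTypeIndex() { return kFunctionNode; }
  FunctionNode() { type_index_ = kFunctionNode; }
};

// Checked downcast in the spirit of ObjectRef::as<T>(): nullptr on mismatch.
template <typename T>
const T* As(const Object* obj) {
  if (obj != nullptr && obj->IsInstance<T>()) return static_cast<const T*>(obj);
  return nullptr;
}
```

With these definitions, As<BaseFuncNode> succeeds for a FunctionNode, whose chain FunctionNode → BaseFuncNode → RelayExprNode → Object contains kBaseFuncNode, but returns nullptr for a plain RelayExprNode, whose chain never reaches kBaseFuncNode.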

The get method of ObjectRef returns a const Object*:

 const Object* get() const { return data_.get(); }

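So in as(), data_.get() returns the Object* held by the reference. Once IsInstance has confirmed that this object is a BaseFuncNode (whose base class is Object) or a subclass of it, static_cast turns it into a BaseFuncNode pointer. In this way, as() performs a checked downcast along an inheritance chain, and together with GetRef it converts a reference of a parent type (here RelayExpr) into a reference of a child type (here BaseFunc).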

The conversion in IRModule::FromExprInContext succeeds because what we pass into C++ is actually a Function, a subclass of BaseFunc. If the object behind expr really were only a RelayExpr (the parent class of BaseFunc) and not a BaseFuncNode, the conversion would fail and as() would return nullptr.


Added by arhunter on Fri, 18 Feb 2022 13:25:04 +0200