[Source code analysis] PyTorch distributed Autograd (4) -- how to enter the engine


0x00 Summary

In the previous article, we saw a series of basic classes such as AutogradMetadata, DistAutogradContainer and DistAutogradContext. We learned how distributed autograd passes information over RPC, how nodes interact with each other, and how these sessions are distinguished and maintained. The main purpose of this article is to see how backward propagation enters the engine.

0x01 Recap

Let's recall the contents of the previous articles.

First, for distributed autograd, we need to track all RPCs during forward propagation to ensure that backward propagation is performed correctly. To this end, when executing an RPC we attach send and recv functions to the autograd graph.

  • The send function is attached to the originating node of the RPC, and its output edges point to the autograd functions of the RPC input tensors. During backward propagation, the input of the send function is received from the destination node and is the output of the corresponding recv function.
  • The recv function is attached to the destination node of the RPC, and its inputs are retrieved from operators executed on the destination node using the input tensors. During backward propagation, the output gradients of the recv function are sent to the source node as the inputs of the corresponding send function.
  • Each send-recv pair is assigned a globally unique autograd_message_id that uniquely identifies the pair. This is useful for looking up the corresponding function on a remote node during backward propagation.
  • For RRef, whenever we call torch.distributed.rpc.RRef.to_here(), we attach an appropriate send-recv pair for the tensors involved.

Secondly, in the concrete forward-propagation code, we store the send and recv functions of each autograd pass in the context. This guarantees that we hold references to the appropriate nodes in the autograd graph and keeps them alive. It also makes it easy to find the corresponding send and recv functions during backward propagation.
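
As a minimal sketch (section 2.3 below does this in full), the send and recv functions recorded in a context can be listed through the private inspection helpers; this assumes the RPC framework has already been initialized and that a peer named "worker1" exists, as in the examples later in this article:

import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

# Minimal sketch: list the send/recv functions recorded in the current context.
with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))

    ctx = dist_autograd._retrieve_context(context_id)
    print(ctx._send_functions())  # {autograd_message_id: SendRpcBackward}
    print(ctx._recv_functions())  # {autograd_message_id: RecvRpcBackward}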

Thirdly, the following are some message definitions in torch/csrc/distributed/rpc/message.h:

// Messages with autograd info
FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,

// Messages to propagate gradients on the backward pass.
BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

Earlier, we saw how FORWARD_AUTOGRAD_REQ is used during forward propagation. Assume the following code: rpc.rpc_sync("worker1", torch.add, args=(t1, t2)). The call sequence is:

  • rpc_sync calls _invoke_rpc.
  • _invoke_rpc calls _invoke_rpc_builtin.
  • This then reaches pyRpcBuiltin, which in turn calls sendMessageWithAutograd.
  • sendMessageWithAutograd internally builds a FORWARD_AUTOGRAD_REQ message and finally sends it over RPC.

So far, we have several questions about the overall process:

  • At the root of the backward computation graph, how is backward propagation initiated and passed on to the next stage of backward propagation?
  • In the intermediate stages of backward propagation, when is BACKWARD_AUTOGRAD_REQ sent? When is the recv operation called? Where is recvAutogradFunctions_ set in the context?
  • How do these two stages enter the distributed autograd engine?

Next, we will analyze these questions; the core issue is how execution enters the dist.autograd engine.

0x02 Computation graph

Let's start with a few computation-graph examples.

2.1 Plain example

First, let's look at ordinary (non-distributed) computation. This is the local version of the example in the official dist.autograd design diagram. You can see that the computation graph is composed of AddBackward0, AccumulateGrad and MulBackward0 nodes.

t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
t3 = t1 + t2
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
next_functions = t5.grad_fn.next_functions

The specific correspondence is shown in the figure below:
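
As a minimal sketch, we can verify this correspondence by printing the node names reachable from t5.grad_fn (repeating the tensors from the snippet above; names such as MulBackward0 and AccumulateGrad match the nodes in the figure):

import torch

# Minimal sketch: walk one level of the local autograd graph from the example above.
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
t3 = t1 + t2
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

print(t5.grad_fn.name())                 # MulBackward0
for fn, _ in t5.grad_fn.next_functions:  # AddBackward0 (for t3), AccumulateGrad (for t4)
    print(fn.name())
add_fn = t5.grad_fn.next_functions[0][0]
for fn, _ in add_fn.next_functions:      # AccumulateGrad for t1 and t2
    print(fn.name())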

2.2 Distributed example

Now let's look at a distributed example. This code roughly corresponds to the diagram in the official design document; we name the result of torch.mul(t3, t4) t5 and add a loss.

def worker0():
    # On worker 0:

    # Setup the autograd context. Computations that take
    # part in the distributed backward pass must be within
    # the distributed autograd context manager.
    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)

      # Perform some computation remotely.
      t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))

      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)

      # Compute some loss.
      loss = t5.sum()

      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])

      # Retrieve the gradients from the context.
      dist_autograd.get_gradients(context_id)

      print(loss)

In the distributed setting, t3 is computed on a different worker.

  • t5 corresponds to mul; t5.grad_fn is <MulBackward0 object at 0x7fbf18d297b8>.
  • t3.grad_fn is <CppFunction object at 0x7fbf18d11a20>, i.e. recv corresponds to a CppFunction.
  • loss is tensor(5.5680, grad_fn=<SumBackward0>).
  • The grad_fn of the remaining tensors is None.

Let's show the design diagram again. The example code above is worker 0 on the left of the figure below; t3 actually runs on worker 1. You can see some characteristics of the distributed context.

2.3 Annotated distributed version

For a better explanation, we add some print statements as annotations.

def _verify_send(send_function):
    print(send_function.name())
    next_funcs = send_function.next_functions
    print(next_funcs[0][0].name())
    print(next_funcs[1][0].name())

def _verify_recv(recv_function):
    print(recv_function.name())
    next_funcs = recv_function.next_functions
    print(len(next_funcs))

def worker0():
    # On worker 0:

    # Setup the autograd context. Computations that take
    # part in the distributed backward pass must be within
    # the distributed autograd context manager.
    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)

      # Perform some computation remotely.
      #t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
      t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))

      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)

      # Compute some loss.
      loss = t5.sum()

      print("--- send ---")
      ctx = dist_autograd._retrieve_context(context_id)
      send_functions = ctx._send_functions()
      _verify_send(list(send_functions.values())[0])

      print("--- loss ---")
      print(loss)
      mul_func = loss.grad_fn.next_functions[0][0]
      print(mul_func.name())
      next_funcs = mul_func.next_functions
      print(next_funcs[0][0].name())
      print(next_funcs[1][0].name())

      print("---- recv ----")
      recv_functions = ctx._recv_functions()
      _verify_recv(list(recv_functions.values())[0])

      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])

      # Retrieve the gradients from the context.
      dist_autograd.get_gradients(context_id)

The print result is:

--- send ---
torch::distributed::autograd::SendRpcBackward
torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
        
--- loss ---
tensor(3.5197, grad_fn=<SumBackward0>)
MulBackward0
torch::distributed::autograd::RecvRpcBackward
torch::autograd::AccumulateGrad
        
---- recv ----
torch::distributed::autograd::RecvRpcBackward

After adding the distributed-related operators, the diagram is as follows:

0x03 Backward propagation

Next, let's see how execution enters the dist.autograd engine. Combined with our diagram, the questions are:

  • How does worker 0 actively initiate backward propagation and then enter the distributed engine?
  • How does worker 0 internally initiate a backward-propagation request to worker 1?
  • How does worker 1 passively receive the backward-propagation message and then enter the distributed engine?

3.1 Initiating backward propagation

Let's find out how backward propagation is initiated. There are two kinds:

  • One is explicit initiation, such as calling the backward method on the loss of worker 0 in the figure above.
  • The other is implicit internal initiation; for example, in the figure above, t3 of worker 0 tells worker 1 through recv that it should start backward propagation.

3.1.1 Explicit external initiation

3.1.1.1 Example

Let's look at how the backward of distributed autograd is called explicitly, going from top to bottom. The call is shown in the following example.

def worker0():
    # On worker 0:

    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)

      # Perform some computation remotely.
      t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))

      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)

      # Compute some loss.
      loss = t5.sum()

      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])  # this call enters the C++ world

3.1.1.2 The C++ world

In torch/_C/_distributed_autograd.pyi, we can see the following comment:

# This module is defined in torch/csrc/distributed/autograd/init.cpp

So let's go to the torch/csrc/distributed/autograd/init.cpp file.

Part of the code is omitted. Here you can see how a context is generated and how backward, get_gradients, etc. are defined.

PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
  auto autograd_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings");
  auto module = py::handle(m).cast<py::module>();

  auto distAutogradContext =
      shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
          .def(
              "_context_id",
              &DistAutogradContext::contextId,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_recv_functions",
              [](const DistAutogradContext& ctx) {
                std::map<int64_t, py::object> funcs;
                for (const auto& map_entry : ctx.recvFunctions()) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              })
          .def(
              "_send_functions",
              [](const ContextPtr& ctx) {
                std::map<int64_t, py::object> funcs;
                for (const auto& map_entry : ctx->sendFunctions()) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              })
          .def("_known_worker_ids", &DistAutogradContext::getKnownWorkerIds);

  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference);

  py::options options;
  options.disable_function_signatures();

  module.def(
      "backward",
      backward,
      py::arg("contextId"),
      py::arg("roots"),
      py::arg("retain_graph") = false,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_gradients",
      [](int64_t contextId) -> py::dict {
        const auto& autogradContext =
            DistAutogradContainer::getInstance().retrieveContext(contextId);
        return torch::jit::toPyObject(IValue(autogradContext->getGradients()));
      },
      py::arg("context_id"));

  Py_RETURN_TRUE;
}
} // namespace

The backward function itself is defined in torch/csrc/distributed/autograd/autograd.cpp.

void backward(
    int64_t context_id,
    const variable_list& roots,
    bool retain_graph) {
  RECORD_FUNCTION(
      kDistAutogradBackwardProfilingKey, std::vector<c10::IValue>());
  try {
    DistEngine::getInstance().execute(context_id, roots, retain_graph);
  } catch (std::exception& e) {
    // FIXME: crashes if exception type is not RuntimeError
    throw std::runtime_error(e.what());
  }
}

As you can see, DistEngine::getInstance().execute(context_id, roots, retain_graph) is called to perform backward propagation, and this is where we enter the engine.
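
For reference, a minimal Python-side sketch of this explicit entry path is shown below; it assumes rpc.init_rpc has been called and that a peer named "worker1" exists, as in the earlier examples, and retain_graph defaults to False, matching the binding in init.cpp.

import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

# Minimal sketch: backward() below ends up in
# DistEngine::getInstance().execute(context_id, roots, retain_graph).
with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()

    dist_autograd.backward(context_id, [loss], retain_graph=False)

    # Distributed autograd accumulates gradients in the context rather than
    # in .grad; get_gradients returns a dict mapping tensor -> gradient.
    grads = dist_autograd.get_gradients(context_id)
    print(grads[t1])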

3.1.2 Implicit internal initiation

Because this is initiated implicitly, the code is relatively hidden, so this time we work bottom-up to unravel it. We know that when backward propagation needs to cross nodes, a BACKWARD_AUTOGRAD_REQ is sent, so we start searching from BACKWARD_AUTOGRAD_REQ.

3.1.2.1 BACKWARD_AUTOGRAD_REQ

In torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp, PropagateGradientsReq::toMessageImpl uses BACKWARD_AUTOGRAD_REQ.

Message PropagateGradientsReq::toMessageImpl() && {
  std::vector<at::IValue> ivalues;
  // Add all the grad tensors.
  for (const auto& grad : grads_) {
    ivalues.emplace_back(grad);
  }

  // Now add autograd metadata.
  ivalues.emplace_back(autogradMetadata_.autogradContextId);
  ivalues.emplace_back(autogradMetadata_.autogradMessageId);

  // Add retain graph.
  ivalues.emplace_back(retainGraph_);

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> payload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  return Message(
      std::move(payload),
      std::move(tensorTable),
      MessageType::BACKWARD_AUTOGRAD_REQ); // It will be used here
}

3.1.2.2 PropagateGradientsReq

Let's continue and find out who sends BACKWARD_AUTOGRAD_REQ, i.e., who calls toMessageImpl. It turns out to be torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp: PropagateGradientsReq is built there, and toMessage is used to build the message. In other words, calling RecvRpcBackward sends a BACKWARD_AUTOGRAD_REQ.

variable_list RecvRpcBackward::apply(variable_list&& grads) { // Call Node
  std::vector<Variable> outputGrads;
  for (size_t i = 0; i < grads.size(); i++) {
    const auto& grad = grads[i];
    if (grad.defined()) {
      outputGrads.emplace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      outputGrads.emplace_back(input_metadata(i).zeros_like());
    }
  }

  auto sharedContext = autogradContext_.lock();
  // Send the gradients over the wire and record the future in the autograd
  // context.
  PropagateGradientsReq gradCall( // Here, PropagateGradientsReq is built
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node.
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send( // Send it to the next node in the backward propagation process
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(), // PropagateGradientsReq::toMessageImpl is called here
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context.
  sharedContext->addOutstandingRpc(jitFuture);

  // 'recv' function sends the gradients over the wire using RPC, it doesn't
  // need to return anything for any downstream autograd function.
  return variable_list();
}

Therefore, we know that when RecvRpcBackward is executed, a BACKWARD_AUTOGRAD_REQ is sent to the next node. Where is RecvRpcBackward called? We will cover that when we introduce DistEngine in the next article.

At this point, the details are as follows: t3 of worker 0 sends a BACKWARD_AUTOGRAD_REQ message to worker 1.

                                                                +
                                                       worker 0 | worker 1
                                                                |
                                                                |
 RecvRpcBackward                         PropagateGradientsReq  |
       +                                          +             |
       |                                          |             |
       |                                          |             |
       |                                          |             |
       v                                          |             |
                                                  |             |
     apply()                                      |             |
       +                                          |             |
       |                                          v             |
       |                                                        |
       | +------------------------------>  toMessageImpl        |
       |                                          +             |
       |                                          |             |
       |   Message(BACKWARD_AUTOGRAD_REQ)         |             |
       | <----------------------------------------+             |
       |                                                        |
       |                                                        |
       v                                                        |
                                                                |
rpcAgent->send(Message)  +-------------------------------------------->
       +                             BACKWARD_AUTOGRAD_REQ      |
       |                                                        |
       |                                                        |
       v                                                        |
                                                                +

The corresponding example diagram is:

3.2 Receiving backward propagation

Let's look at how the receiving side handles backward propagation. We return to worker 1, i.e., how the send node in the figure receives the backward-propagation message.

3.2.1 Receiving the message

When the TensorPipeAgent is generated, RequestCallbackImpl is configured as its callback function. This is the agent's unified response function. When we covered the agent's receive logic earlier, we saw that execution enters the RequestCallbackNoPython::processRpc function, where we can see the processing logic for BACKWARD_AUTOGRAD_REQ.

This follows the normal RPC processing path.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

  switch (messageType) {

    case MessageType::BACKWARD_AUTOGRAD_REQ: { 
      processBackwardAutogradReq(rpc, messageId, responseFuture); // Call here
      return;
    };

3.2.2 Processing the message

In processBackwardAutogradReq:

  • Get the DistAutogradContainer.
  • Get the context, which was established earlier during forward propagation. As seen above, in this example worker 0 and worker 1 share the same context id for each autograd pass.
  • Using the sender's context id and message id, retrieve the corresponding SendRpcBackward from the context. Here we see how the context is used.
  • With sendFunction as a parameter, call executeSendFunctionAsync to hand processing over to the engine.

void RequestCallbackNoPython::processBackwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture) const {
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId); // Get the context id of the sender

  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction = // sendFunction is obtained according to the sender context id and message id
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads()); // Set gradient

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( // Call engine
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the rpcs come back.
  execFuture->addCallback([responseFuture, messageId](JitFuture& execFuture) {
    if (!execFuture.hasError()) {
      Message m = std::move(PropagateGradientsResp()).toMessage();
      m.setId(messageId);
      responseFuture->markCompleted(
          IValue(c10::make_intrusive<Message>(std::move(m))));
    } else {
      responseFuture->setError(execFuture.exception_ptr());
    }
  });
}

In worker 1's DistEngine::executeSendFunctionAsync, the engine performs its processing and eventually sends BACKWARD_AUTOGRAD_REQ to the downstream nodes of its backward pass, so we continue to extend the example diagram and add another BACKWARD_AUTOGRAD_REQ.

3.3 Summary

We can see that there are two ways to enter the dist.autograd engine and start backward propagation:

  • One is that the example code explicitly calls backward, which in turn calls DistEngine::getInstance().execute; this is worker 0.
  • The other is the passive call to DistEngine::getInstance().executeSendFunctionAsync, i.e. worker 1 (of course, the send of worker 0 also corresponds to a passive call).
