Pytorch Extension with a Makefile

PyTorch is a great neural network library that has both flexibility and power. Personally, I think it is the best neural network library for quickly prototyping (advanced) dynamic neural networks and deploying them to applications.

Recently, PyTorch was upgraded to version 1.0, which uses the ATen tensor library for all backends and supports custom C++ extensions. Before C++ extensions, custom extensions were built through CFFI (C Foreign Function Interface).

As an avid CUDA developer, I created multiple projects that speed up custom PyTorch layers using the CFFI interface. However, wrapping functions in a non-object-oriented language (C) sometimes led to ridiculous overhead when complex objects were required. Now that PyTorch supports the latest technology from 2011, C++11, we can finally use object-oriented programming for PyTorch extensions!

In this tutorial, I will cover some drawbacks of the current setuptools-based build and show you how to use a Makefile for PyTorch C++ extension development. The source code for the tutorial can be found here.

Before you proceed, please read the official PyTorch C++ extension guide, which provides an extensive and useful tutorial on how to create a C++ extension with ATen.

Drawbacks of Setuptools for Development

Setuptools, however, is not very flexible, as it primarily focuses on deploying a project. As a result, it lacks many features that are essential for fast development. Let's look at a few scenarios that I encountered while porting my PyTorch CFFI extensions to C++ extensions.

Compile only updated files

When you develop a large project, you don't want to recompile the entire project every time you make a small change. With setuptools, however, ALL source files are recompiled into objects every time you make a change. This becomes extremely cumbersome as your project grows.

A Makefile, in contrast, keeps the object files around, and since you control every build rule, only the files that changed since the last build are recompiled. This is extremely useful when you make a small change to one file and want to debug your project quickly.
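The rule make applies here can be modeled in a few lines of C++. This is a simplified sketch, not make's actual implementation: timestamps are plain integers standing in for file modification times, and the function name needsRebuild is hypothetical.

```cpp
#include <algorithm>
#include <optional>
#include <vector>

// Stand-in for a file modification time (hypothetical simplification).
using Timestamp = long long;

// Simplified model of make's rule: an object file (the target) must be
// rebuilt if it does not exist yet, or if any prerequisite (the .cpp/.cu
// source or a header it depends on) was modified after the target.
bool needsRebuild(std::optional<Timestamp> target,
                  const std::vector<Timestamp>& prerequisites) {
  if (!target) return true;  // the object file was never built
  return std::any_of(prerequisites.begin(), prerequisites.end(),
                     [&](Timestamp t) { return t > *target; });
}
```

For instance, saving foo.cpp makes it newer than foo.o, so only foo.o is rebuilt, while every other object file is reused as-is.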

Parallel Compilation

Another problem with setuptools is that it compiles files sequentially. When your project gets large, you want to compile many files in parallel. With a Makefile, you can parallelize compilation with the -j# flag. For example, make -j8 runs up to 8 compilation jobs in parallel.
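As a loose analogy for what -j8 buys you, the C++ sketch below runs a list of independent jobs (think of one compile command per source file) with at most maxParallel of them in flight at once. It is not how make works internally: make spawns compiler processes and respects dependencies between targets, whereas this hypothetical runParallel simply uses threads for jobs that have no dependencies.

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Runs every job exactly once, with at most `maxParallel` running at a time.
// Each worker thread repeatedly claims the next unclaimed job index.
void runParallel(const std::vector<std::function<void()>>& jobs,
                 std::size_t maxParallel) {
  std::atomic<std::size_t> next{0};  // index of the next job to claim
  auto worker = [&] {
    for (std::size_t i = next++; i < jobs.size(); i = next++) {
      jobs[i]();
    }
  };
  // Never start more workers than there are jobs, but always at least one.
  std::size_t n = std::max<std::size_t>(1, std::min(maxParallel, jobs.size()));
  std::vector<std::thread> workers;
  for (std::size_t w = 0; w < n; ++w) workers.emplace_back(worker);
  for (auto& t : workers) t.join();
}
```

When the jobs take similar time, four workers ideally cut the wall-clock time to about a quarter of a sequential build, which is the kind of speedup -j4 aims for.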

Debugging

The current PyTorch C++ extension build does not allow debugging, even with the debug flag. With a Makefile, in contrast, you can easily pass -g (or -g -G for nvcc). To enable it, uncomment line 3 (DEBUG=1) of the Makefile and line 20 of setup.py.

Making a pytorch extension with a Makefile

Now that we have covered some of the advantages of using a Makefile for a PyTorch C++ extension, let's get into the details of how to make one.

Creating Objects and Functions

As an example, in this tutorial we will create a class and a CUDA function that are callable from Python. First, let's make a simple class that provides a setter and a getter for a private variable key_.

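#include <cstdint>
#include <string>

class Foo {
private:
  uint64_t key_ = 0;  // initialized so that getKey() before setKey() is well-defined

public:
  void setKey(uint64_t key);
  uint64_t getKey();
  std::string toString() const {
    return "< Foo, key: " + std::to_string(key_) + " > ";
  }
};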
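For reference, a minimal sketch of what foo.cpp could contain is shown below; the actual file is in the tutorial's source. So that the snippet compiles on its own, it repeats the class definition (in the project, foo.cpp would include the header instead), and it initializes key_ to 0 so that calling getKey() before setKey() is well-defined.

```cpp
#include <cstdint>
#include <string>

// Class definition repeated here only so this sketch is self-contained.
class Foo {
 private:
  std::uint64_t key_ = 0;

 public:
  void setKey(std::uint64_t key);
  std::uint64_t getKey();
  std::string toString() const {
    return "< Foo, key: " + std::to_string(key_) + " > ";
  }
};

// foo.cpp: out-of-class definitions of the setter and the getter.
void Foo::setKey(std::uint64_t key) { key_ = key; }
std::uint64_t Foo::getKey() { return key_; }
```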

The setter and the getter are defined in foo.cpp. Next, I created a simple CUDA function that adds two vectors and writes the result into an output at::Tensor.

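template <typename Dtype>
__global__ void sum(Dtype *a, Dtype *b, Dtype *c, int N) {
  // Global index of this thread across all blocks
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // The last block may contain more threads than remaining elements
  if (i < N) {
    c[i] = a[i] + b[i];
  }
}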

// GET_BLOCKS, CUDA_NUM_THREADS and Formatter are helpers from the
// tutorial's source (not shown here).
template <typename Dtype>
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
                  cudaStream_t stream) {
  sum<Dtype>
      <<<GET_BLOCKS(N), CUDA_NUM_THREADS, 0, stream>>>(in_a, in_b, out_c, N);

  // Kernel launches are asynchronous; this reports launch errors
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    throw std::runtime_error(Formatter() << "CUDA kernel failed : "
                                         << cudaGetErrorString(err));
}
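GET_BLOCKS and CUDA_NUM_THREADS are defined in the tutorial's source. The CPU sketch below assumes the common Caffe-style definitions (512 threads per block, with N rounded up to a whole number of blocks); both are assumptions, and emulateSum is a hypothetical name. Emulating every (block, thread) pair of the launch shows why the kernel needs a strict i < N guard: the last block usually contains threads whose index falls past the end of the arrays.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Assumed Caffe-style launch helpers (hypothetical values).
constexpr int CUDA_NUM_THREADS = 512;
inline int GET_BLOCKS(int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;  // ceil(N / threads)
}

// CPU emulation of sum<<<GET_BLOCKS(N), CUDA_NUM_THREADS>>>(a, b, c, N):
// each (block, thread) pair computes one global index i, and the guard
// discards the padding threads of the last block.
template <typename Dtype>
std::vector<Dtype> emulateSum(const std::vector<Dtype>& a,
                              const std::vector<Dtype>& b) {
  if (a.size() != b.size()) throw std::invalid_argument("size mismatch");
  const int N = static_cast<int>(a.size());
  std::vector<Dtype> c(a.size());
  for (int block = 0; block < GET_BLOCKS(N); ++block) {
    for (int thread = 0; thread < CUDA_NUM_THREADS; ++thread) {
      int i = block * CUDA_NUM_THREADS + thread;  // blockIdx.x * blockDim.x + threadIdx.x
      if (i < N) c[i] = a[i] + b[i];
    }
  }
  return c;
}
```

For N = 1000, this launches 2 blocks of 512 threads, and the 24 threads with i >= 1000 do nothing. With a guard of i <= N, the thread with i == N would write one element past the end of c.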

Note that AddGPUKernel throws a std::runtime_error when the launch fails. Pybind11 automatically converts standard C++ exceptions to Python exceptions: std::runtime_error is mapped to RuntimeError, and std::invalid_argument (used in AddGPU below) is mapped to ValueError. This prevents the Python interpreter from crashing and lets Python code handle errors gracefully. More on error handling with pybind11 can be found here.
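The Formatter used in the error messages is also defined in the tutorial's source, not in this post. A minimal sketch of such a helper, assuming it only needs to accept streamed values and convert to std::string (which is what passing it to std::runtime_error requires), could look like this:

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical sketch of a Formatter helper: collects streamed values and
// converts implicitly to std::string, so it can be passed directly to an
// exception constructor.
class Formatter {
 public:
  template <typename T>
  Formatter& operator<<(const T& value) {
    stream_ << value;
    return *this;  // return *this so calls can be chained
  }
  operator std::string() const { return stream_.str(); }

 private:
  std::ostringstream stream_;
};
```

Because operator<< returns a reference to the same Formatter, a whole message can be built in one expression and handed to an exception constructor, as AddGPU and AddGPUKernel do.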

Bridging CPP with Pybind

PyTorch passes tensors as the at::Tensor type. To extract a mutable raw pointer, use .data<Dtype>(). For example, to extract the raw pointer from a float tensor A, use A.data<float>() (newer PyTorch releases deprecate .data<Dtype>() in favor of .data_ptr<Dtype>()). In addition, to use the CUDA stream of the current context, call at::cuda::getCurrentCUDAStream(), which is declared in ATen/cuda/CUDAContext.h.

template <typename Dtype>
void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
  int N = in_a.numel();
  if (N != in_b.numel())
    throw std::invalid_argument(Formatter()
                                << "Size mismatch A.numel(): " << in_a.numel()
                                << ", B.numel(): " << in_b.numel());

  out_c.resize_({N});

  AddGPUKernel<Dtype>(in_a.data<Dtype>(), in_b.data<Dtype>(),
                      out_c.data<Dtype>(), N, at::cuda::getCurrentCUDAStream());
}

Once bound with pybind11, the function above can be called directly from Python. Now, let's bind the C++ function and the class to Python.

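#include <torch/extension.h>  // brings in ATen and pybind11

namespace py = pybind11;

// TORCH_EXTENSION_NAME is set by the build (-DTORCH_EXTENSION_NAME=...)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  std::string name = std::string("Foo");
  py::class_<Foo>(m, name.c_str())
      .def(py::init<>())
      .def("setKey", &Foo::setKey)
      .def("getKey", &Foo::getKey)
      .def("__repr__", [](const Foo &a) { return a.toString(); });

  // Only concrete template instantiations can be bound
  m.def("AddGPU", &AddGPU<float>);
}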

For classes, use py::class_<CLASS> so that pybind11 knows it is a class, and then expose the member functions you want to call from Python with .def. Free functions can be attached directly to the module with m.def.

Compiling the project

Now that all source files are ready, let's compile them. First, we build an archive (static) library that contains all classes and functions. Then, we compile the file that binds these functions and classes with pybind11 using setuptools, and load the result in Python.

Finding the Arguments and the Include Paths

First, we have to find the right arguments for compiling a PyTorch extension. This is actually easy: when you compile your project with setuptools, it prints the actual compilation commands it invokes, from which we can deduce what a Makefile needs. The extra arguments it uses for PyTorch extensions are:

-DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=$(EXTENSION_NAME) -D_GLIBCXX_USE_CXX11_ABI=0

In addition, we need the header paths. Since PyTorch builds C++ extensions with setuptools, we can see how it finds the headers and libraries: the function include_paths in torch.utils.cpp_extension returns all header paths, and we only need to pass them to the Makefile. A Makefile can run a Python command and capture its output, as in the following line for the Python headers. (2019-07-03: PyTorch now supports the ABI flag explicitly. See the update below.)

PYTHON_HEADER_DIR := $(shell python -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')

The PyTorch header paths are obtained the same way, with a Python command that prints each entry of include_paths() on its own line; the Makefile then iterates over the paths and prepends -I to each. The final Makefile can be found here.

Archive Libraries

Once the object files are built, we bundle them into an archive file (a static library) and link it to the main pybind entry point. We can do so with:

	ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)

Compiling the bind file

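When the archive library is ready, we can finally compile the binding file, which links the classes and functions in the archive into the Python module. Since libraries=['make_pytorch'] is passed to the linker as -lmake_pytorch, the archive has to be named libmake_pytorch.a and placed in the directory given by library_dirs.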

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    ...
    ext_modules=[
        CUDAExtension(
            name='MakePytorchBackend',
            include_dirs=['./'],
            sources=[
                'pybind/bind.cpp',
            ],
            libraries=['make_pytorch'],
            library_dirs=['objs'],
            # extra_compile_args=['-g']
        )
    ],
    cmdclass={'build_ext': BuildExtension},
    ...
)

Finally, if the Makefile calls setup.py itself, a single make -j8 compiles all files, builds the binding, and installs it into the Python library:

all: $(STATIC_LIB)
	python setup.py install --force

Update 2019-07-03

PyTorch v1.1 now exposes the ABI flag explicitly as torch._C._GLIBCXX_USE_CXX11_ABI. For an example of its use, see the MinkowskiEngine Makefile.
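The flag matters because libstdc++ changes the layout of std::string and std::list depending on _GLIBCXX_USE_CXX11_ABI, so libraries compiled with different settings cannot safely exchange those types. The small check below (the function name compiledCxx11Abi is hypothetical) reports which setting a translation unit was compiled with, which you can compare with torch._C._GLIBCXX_USE_CXX11_ABI in Python (True corresponds to 1).

```cpp
#include <string>  // any libstdc++ header defines _GLIBCXX_USE_CXX11_ABI

// Returns the libstdc++ dual-ABI setting of this translation unit:
// 1 for the C++11 ABI, 0 for the old ABI, or -1 if the macro is not
// defined (for example, when building against libc++ instead).
int compiledCxx11Abi() {
#ifdef _GLIBCXX_USE_CXX11_ABI
  return _GLIBCXX_USE_CXX11_ABI;
#else
  return -1;
#endif
}
```

Compiling it with -D_GLIBCXX_USE_CXX11_ABI=0, as in the flags listed earlier, makes it return 0 under libstdc++.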
