Making a Caffe Layer

Caffe is one of the most popular open-source neural network frameworks. It is modular, clean, and fast. Extending it is tricky but not as difficult as extending other frameworks.

Files to modify or create

All paths are relative to $CAFFE_HOME.

  • /src/caffe/proto/caffe.proto
  • /include/caffe/common_layers.hpp or vision_layers.hpp
  • /src/caffe/layer_factory.cpp
  • /src/caffe/layers/new_layer.cpp
  • /src/caffe/layers/new_layer.cu
  • /src/caffe/test/test_new_layer.cpp

File 1: caffe.proto

You have to give your new layer a new enum ID. Search caffe.proto for the phrase "next available ID"; it appears on two lines. Increment the next available ID and define the new layer type.
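
A sketch of the edit (the enum value here is hypothetical; use the actual "next available ID" from your copy of caffe.proto and bump the comment afterwards):

// caffe.proto, inside message LayerParameter
enum LayerType {
  // ... existing layer types ...
  ANGLE_TO_TRIG = 39;  // hypothetical: the current "next available ID"
}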

File 2: layer_factory.cpp

You have to add two lines to the switch statement that instantiates a layer from its type.
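
A sketch of the two added lines inside GetLayer() (the enum constant must match the type defined in caffe.proto):

// layer_factory.cpp, inside GetLayer(const LayerParameter& param)
switch (param.type()) {
  // ... existing cases ...
  case LayerParameter_LayerType_ANGLE_TO_TRIG:  // added
    return new AngleToTrigLayer<Dtype>(param);  // added
  // ...
}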

File 3: Layer Header

Define your layer in a common layer header file. Use either common_layers.hpp or vision_layers.hpp, depending on the type of the layer.
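
For the layer developed below, the declaration might look like the following (a sketch; tmp_ is the scratch blob used by Backward_cpu):

template <typename Dtype>
class AngleToTrigLayer : public Layer<Dtype> {
 public:
  explicit AngleToTrigLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);

  Blob<Dtype> tmp_;  // scratch space for the backward pass
};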

Files 4 & 5: Defining a Layer

The layer has to inherit from the Layer base class. The pure virtual functions that you have to implement are the ones declared = 0:

virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top) = 0;
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top) = 0;
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
  const vector<bool>& propagate_down,
  vector<Blob<Dtype>*>* bottom) = 0;

File 6: Test File

Every layer in Caffe must have a corresponding unit test file, and the unit test must thoroughly check all the functionality implemented. Make a file /src/caffe/test/test_new_layer.cpp and use the provided Caffe unit test macros, such as

EXPECT_NEAR
EXPECT_GE
EXPECT_LE
EXPECT_EQ

Finally, check backpropagation using the GradientChecker, as in the sketch below.
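
A minimal gradient-check test might look like this (a sketch; the AngleToTrigLayerTest fixture and its blob_bottom_vec_ / blob_top_vec_ members follow the same pattern as the existing test files):

#include "caffe/test/test_gradient_check_util.hpp"

TYPED_TEST(AngleToTrigLayerTest, TestGradientCPU) {
  LayerParameter layer_param;
  AngleToTrigLayer<TypeParam> layer(layer_param);
  GradientChecker<TypeParam> checker(1e-2, 1e-3);
  // Compares Backward_cpu against numeric gradients of Forward_cpu.
  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
      &(this->blob_top_vec_));
}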

Compile and Test

Run the following commands from $CAFFE_HOME:

make
make test
./build/test/test_new_layer.testbin

Implementation Details

When you implement the functions, use the macros and functions provided by Caffe to minimize your workload. A short usage sketch follows this list.

  • Blob offset

    When you compute an offset from a blob's data pointer, use the safe offset(n, c) helper instead of computing indices by hand.

  • Basic Math Functions

    caffe_[mul|add|sub|div|sqr|powx|exp|abs|sin|cos|copy|scal|cpu_axpby]

    Basic elementwise functions and matrix multiplication are provided in include/caffe/util/math_functions.hpp.

  • CUDA Macros

    There are several CUDA macros that come in very handy when implementing Forward_gpu and Backward_gpu, such as CUDA_KERNEL_LOOP below and the CAFFE_GET_BLOCKS / CAFFE_CUDA_NUM_THREADS launch helpers used later in this post.

    // CUDA: grid stride looping
    #define CUDA_KERNEL_LOOP(i, n) \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
           i < (n); \
           i += blockDim.x * gridDim.x)
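
For instance, combining offset() with one of the elementwise calls (a sketch; the blob names mirror the layer code below):

// Square the channel values of sample n: out = in .* in
const Dtype* in = bottom[0]->cpu_data() + bottom[0]->offset(n);
Dtype* out = (*top)[0]->mutable_cpu_data() + (*top)[0]->offset(n);
caffe_mul(bottom[0]->channels(), in, in, out);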
    

CUDA brief summary

Since the Caffe framework heavily relies on CUDA, I’ll briefly summarize the basics of CUDA.

  • Function decorators

    In CUDA terminology, device refers to a CUDA capable GPU and host refers to the CPU side.

    There are two function decorators: __device__ and __global__. Putting either in front of a function compiles it for the GPU. A __device__ function can only be called from device code (e.g., from inside a kernel), whereas a __global__ function is a kernel that you launch from the host.

    A kernel function runs in parallel. There are two levels of parallelism: threads and blocks. A block consists of multiple threads, and a collection of blocks is called a grid. At the hardware level, threads are executed in groups of 32 called warps, so it is best to use a multiple of 32 threads per block.

    You specify the launch configuration on the CPU side when you launch a kernel. For example, to launch a kernel called kernel_function, put the following in the host code: kernel_function<<<N_BLOCKS, N_THREADS>>>(arguments). This launches N_BLOCKS blocks, each running N_THREADS threads. A minimal sketch follows.
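
Here all names are hypothetical; d_a, d_b, and d_y are device pointers allocated with cudaMalloc:

// Elementwise vector addition kernel.
__global__ void add_kernel(const int n, const float* a, const float* b,
    float* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = a[i] + b[i];
}

// Host side: launch enough 256-thread blocks to cover all n elements.
add_kernel<<<(n + 255) / 256, 256>>>(n, d_a, d_b, d_y);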

Angle To Sine Cosine Layer

The layer takes an $N \times C \times 1 \times 1$ Blob and produces an $N \times 2C \times 1 \times 1$ Blob. The angles must be in radians (which is not our concern, since the network weights will adjust automatically).

For each input it produces two values, $\sin(x)$ and $\cos(x)$. Let’s concatenate the $C$ sines with the $C$ cosines. If we define $y_i = \sin(x_i)$ and $y_{i+C} = \cos(x_i)$, the gradient will be

$$ \begin{align} \frac{\partial E(y_i, y_{i+C}, \dots)}{\partial x_i} & = \frac{\partial E(y_i, \dots)}{\partial y_i} \frac{\partial y_i}{\partial x_i} + \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \frac{\partial y_{i+C}}{\partial x_i}\\ & = \frac{\partial E(y_i, \dots)}{\partial y_i} \frac{\partial \sin(x_i)}{\partial x_i} + \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \frac{\partial \cos(x_i)}{\partial x_i}\\ & = \frac{\partial E(y_i, \dots)}{\partial y_i} \cos(x_i) - \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \sin(x_i) \end{align} $$

The $\frac{\partial E(y_i, \dots)}{\partial y_i}$ term comes from the layer above and is stored in top[0]->[c|g]pu_diff.

angle_to_trig.cpp

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void AngleToTrigLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
  // Takes an arbitrary number of angles in radians and returns the sin and
  // cos of the inputs. The cosines are appended after the sines,
  // i.e. [sin_1, sin_2, ..., sin_n, cos_1, cos_2, ..., cos_n]
  CHECK_EQ(bottom[0]->height(), 1);
  CHECK_EQ(bottom[0]->width(), 1);
  // num, channels, height, width,
  (*top)[0]->Reshape(bottom[0]->num(), 2 * bottom[0]->channels(), 1, 1);
  tmp_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1);
}

template <typename Dtype>
void AngleToTrigLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {

  int n_channel = bottom[0]->channels();
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = (*top)[0]->mutable_cpu_data();

  for (int n = 0; n < bottom[0]->num(); ++n) {
    // #(angles) = #(channel)
    caffe_sin(n_channel, bottom_data + bottom[0]->offset(n),
            top_data + (*top)[0]->offset(n));
    caffe_cos(n_channel, bottom_data + bottom[0]->offset(n),
            top_data + (*top)[0]->offset(n, n_channel));
  }
}

template <typename Dtype>
void AngleToTrigLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
  const int n_channel = (*bottom)[0]->channels();
  const Dtype* top_data = top[0]->cpu_data(); // [sin(x) cos(x)], no need to compute again
  const Dtype* top_diff = top[0]->cpu_diff();
  if (propagate_down[0]) {
    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
    for (int n = 0; n < top[0]->num(); ++n) {
      caffe_mul(n_channel, top_diff + top[0]->offset(n),
              top_data + top[0]->offset(n, n_channel),
              bottom_diff + (*bottom)[0]->offset(n));
      caffe_mul(n_channel, top_diff + top[0]->offset(n, n_channel),
              top_data + top[0]->offset(n),
              tmp_.mutable_cpu_data());
      caffe_sub(n_channel, bottom_diff + (*bottom)[0]->offset(n),
              tmp_.cpu_data(),
              bottom_diff + (*bottom)[0]->offset(n));
    }
  }
}

INSTANTIATE_CLASS(AngleToTrigLayer);

}  // namespace caffe

angle_to_trig.cu

#include <algorithm>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {


/* Define automatic type-switching CUDA sin/cos functions:
 * [sin,cos]f are the single-precision trigonometric functions,
 * [sin,cos] are the double-precision trigonometric functions.
 */
template <typename Dtype>
__device__ Dtype auto_type_cos(Dtype x);

template <typename Dtype>
__device__ Dtype auto_type_sin(Dtype x);

template<>
__device__ float auto_type_cos<float>(const float x){  return cosf(x);}
template<>
__device__ double auto_type_cos<double>(const double x){ return cos(x);}
template<>
__device__ float auto_type_sin<float>(const float x){  return sinf(x);}
template<>
__device__ double auto_type_sin<double>(const double x){  return sin(x);}


template <typename Dtype>
__global__ void AngleToTrigLayerForward(const int n, const int n_channel,
    const Dtype* in, Dtype* out) {
  CUDA_KERNEL_LOOP(index, n) {
    unsigned int data_index = index / n_channel;
    unsigned int channel_index = index % n_channel;
    unsigned int top_index = 2 * n_channel * data_index + channel_index;

    out[top_index] =
        auto_type_sin<Dtype>(in[index]);
    out[top_index + n_channel] =
        auto_type_cos<Dtype>(in[index]);
  }
}

template <typename Dtype>
void AngleToTrigLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  const int count = bottom[0]->count();
  const int n_channel = bottom[0]->channels();
  // NOLINT_NEXT_LINE(whitespace/operators)
  AngleToTrigLayerForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, n_channel, bottom_data, top_data);
  CUDA_POST_KERNEL_CHECK;
}

template <typename Dtype>
__global__ void AngleToTrigLayerBackward(const int n, const int n_channel, 
    const Dtype* top_diff, const Dtype* top_data, Dtype* bottom_diff) {
  CUDA_KERNEL_LOOP(index, n) {
    unsigned int data_index = index / n_channel;
    unsigned int channel_index = index % n_channel;
    unsigned int top_index = 2 * n_channel * data_index + channel_index;
    bottom_diff[index] =
        top_diff[top_index] * top_data[top_index + n_channel]
        - top_diff[top_index + n_channel] * top_data[top_index];
  }
}

template <typename Dtype>
void AngleToTrigLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (propagate_down[0]) {
    const Dtype* top_diff = top[0]->gpu_diff();
    const Dtype* top_data = top[0]->gpu_data();
    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
    const int count = (*bottom)[0]->count();
    const int n_channel = (*bottom)[0]->channels();
    // NOLINT_NEXT_LINE(whitespace/operators)
    AngleToTrigLayerBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, n_channel, top_diff, top_data, bottom_diff);
    CUDA_POST_KERNEL_CHECK;
  }
}


INSTANTIATE_CLASS(AngleToTrigLayer);


}  // namespace caffe

Loss Layer

A loss layer traditionally does not have any top outputs, since the loss is the final output. However, in Caffe you can use the top blob to set the scale (loss weight) of a specific loss layer.

The scale is fed into the loss layer through the diff of its top blob:

// Scale gradient
const Dtype loss_weight = top[0]->cpu_diff()[0];

This is common practice, used in many conventional loss layers including the Euclidean loss and the contrastive loss; a sketch of how the weight is applied follows.
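
A sketch of a backward pass, modeled on the Euclidean loss (diff_ is assumed to cache the residual computed in the forward pass):

// Scale the cached gradient by the loss weight, normalized by batch size.
const Dtype loss_weight = top[0]->cpu_diff()[0];
const Dtype alpha = loss_weight / (*bottom)[0]->num();
// bottom_diff = alpha * diff_ + 0 * bottom_diff
caffe_cpu_axpby((*bottom)[0]->count(), alpha, diff_.cpu_data(),
    Dtype(0), (*bottom)[0]->mutable_cpu_diff());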
