Making a Caffe Layer
Caffe is one of the most popular open-source neural network frameworks. It is modular, clean, and fast. Extending it is tricky but not as difficult as extending other frameworks.
Files to modify or create
All paths below are relative to $(CAFFE_HOME).
- /src/caffe/proto/caffe.proto
- /include/caffe/common_layers.hpp or vision_layers.hpp
- /src/caffe/layer_factory.cpp
- /src/caffe/layers/new_layer.cpp
- /src/caffe/layers/new_layer.cu
- /src/caffe/test/test_new_layer.cpp
File 1: caffe.proto
You have to give your new layer a new index. Search caffe.proto for the phrase next available ID; there are two lines containing it. Increment the next available ID and define the new layer.
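As a rough illustration (the ID numbers and surrounding entries here are hypothetical and depend on your Caffe revision), defining the AngleToTrig layer used later in this post would look something like this inside the LayerType enum:
// Hypothetical excerpt from src/caffe/proto/caffe.proto.
// Keep both "next available ID" comments in sync when you add an entry.
// LayerType next available ID: 39 (last added: ANGLE_TO_TRIG)
enum LayerType {
  // ... existing layer types ...
  ANGLE_TO_TRIG = 38;  // illustrative value; use the actual next available ID
}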
File 2: layer_factory.cpp
You have to add two lines that define the switch case for your new layer.
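Assuming the layer was registered as ANGLE_TO_TRIG in caffe.proto (as sketched above), the two lines go into the switch over layer types in GetLayer(); the enum constant name follows protoc's nested-enum naming convention:
// In src/caffe/layer_factory.cpp, inside the switch on the layer type
// (sketch; the surrounding cases already exist in the file).
case LayerParameter_LayerType_ANGLE_TO_TRIG:
  return new AngleToTrigLayer<Dtype>(param);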
File 3: Layer Header
Define your layer in a common layer header file. Use either common_layers.hpp or vision_layers.hpp, depending on the type of the layer.
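For the AngleToTrig layer implemented later in this post, the declaration in common_layers.hpp could look roughly like this (a sketch matching the bottom/top pointer API used throughout; the tmp_ member is the scratch blob used in Backward_cpu):
// Hypothetical declaration in include/caffe/common_layers.hpp.
template <typename Dtype>
class AngleToTrigLayer : public Layer<Dtype> {
 public:
  explicit AngleToTrigLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);

  Blob<Dtype> tmp_;  // scratch buffer used in Backward_cpu
};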
Files 4 & 5: Defining a layer
The layer has to inherit from the Layer base class. The virtual functions that you have to implement are the ones declared = 0, which are:
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
vector<Blob<Dtype>*>* bottom) = 0;
File 6: Test File
All layers in Caffe must have a corresponding unit test file. The unit test must thoroughly check all the implemented functionality. Create a file /src/caffe/test/test_new_layer.cpp and use the unit test macros Caffe provides:
EXPECT_NEAR
EXPECT_GE
EXPECT_LE
EXPECT_EQ
Finally, check backpropagation using the GradientChecker.
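A minimal sketch of such a gradient test, modeled on existing Caffe layer tests (the test fixture and its blob_bottom_vec_ / blob_top_vec_ members are assumed to be set up as in the other test files; exact signatures depend on your Caffe version):
// Hypothetical gradient test in src/caffe/test/test_angle_to_trig_layer.cpp.
TYPED_TEST(AngleToTrigLayerTest, TestGradientCPU) {
  Caffe::set_mode(Caffe::CPU);
  LayerParameter layer_param;
  AngleToTrigLayer<TypeParam> layer(layer_param);
  // stepsize 1e-2, threshold 1e-3: compares analytic and numeric gradients.
  GradientChecker<TypeParam> checker(1e-2, 1e-3);
  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
                                  &(this->blob_top_vec_));
}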
Compile and Test
Run the following commands in $CAFFE_HOME:
make
make test
./build/test/test_new_layer.testbin
Implementation Detail
When you implement these functions, use the macros and helper functions provided by Caffe to minimize your workload.
Blob offset
When you compute the offset from the blob pointer, use the safe offset(n, c) function rather than hand-coding the stride arithmetic.
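As a rough illustration (the blob and loop variables here are placeholders, not code from the layer below):
// Sketch: iterate over items and channels of a 4-D blob. offset(n, c)
// computes (n * channels() + c) * height() * width() for h = w = 0,
// so the stride arithmetic never has to be written by hand.
for (int n = 0; n < bottom[0]->num(); ++n) {
  for (int c = 0; c < bottom[0]->channels(); ++c) {
    const Dtype* channel_data = bottom[0]->cpu_data() + bottom[0]->offset(n, c);
    // ... process height() * width() values starting at channel_data ...
  }
}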
Basic Math Functions
caffe_[mul|add|sub|div|sqr|powx|exp|abs|sin|cos|copy|scal|cpu_axpby]
Basic elementwise functions and matrix multiplication are provided in /caffe/util/math_functions.hpp.
CUDA Macros
There are several CUDA macros that come in very handy when implementing Forward_gpu and Backward_gpu:
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
CUDA brief summary
Since the Caffe framework heavily relies on CUDA, I’ll briefly summarize the basics of CUDA.
Function decorators
In CUDA terminology, device refers to a CUDA-capable GPU and host refers to the CPU side.
There are two function decorators, __device__ and __global__. If you put either of them in front of a function, the function is compiled for the GPU. You can call a __device__ function from within a CUDA kernel, whereas you launch a __global__ kernel from the host.
A kernel function runs in parallel. There are two levels of parallelism: threads and blocks. A block consists of multiple threads, and a collection of blocks is called a grid. Threads are executed in groups of 32 called warps, so it is best to use a multiple of 32 threads per block.
You specify the number of blocks and threads from the host side when you launch a kernel. For example, if you want to launch a kernel called kernel_function, simply put the following in the host code: kernel_function<<<N_BLOCKS, N_THREADS>>>(arguments). This launches N_BLOCKS blocks, each with N_THREADS threads.
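To make this concrete, here is a small standalone toy program (not part of Caffe) that launches a kernel using the same grid-stride pattern that CUDA_KERNEL_LOOP expands to:
// toy_scale.cu -- scale an array on the GPU (illustrative only).
#include <cuda_runtime.h>

__global__ void scale_kernel(const int n, const float alpha, float* data) {
  // Grid-stride loop: each thread handles every (blockDim.x * gridDim.x)-th
  // element, so any n works regardless of the launch configuration.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] *= alpha;
  }
}

int main() {
  const int n = 1 << 20;
  float* d_data;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMemset(d_data, 0, n * sizeof(float));
  // 128 blocks of 256 threads; 256 is a multiple of the 32-thread warp size.
  scale_kernel<<<128, 256>>>(n, 2.0f, d_data);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}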
Angle To Sine Cosine Layer
The layer takes an $N \times C \times 1 \times 1$ Blob and produces an $N \times 2C \times 1 \times 1$ Blob. The angles must be in radians (which is not our concern, since the network weights will adjust automatically).
For each input it produces two values, $\sin(x)$ and $\cos(x)$. Let's concatenate the $C$ sines with the $C$ cosines. If we define $y_i = \sin(x_i)$ and $y_{i+C} = \cos(x_i)$, the gradient will be
$$ \begin{align} \frac{\partial E(y_i, y_{i+C}, \dots)}{\partial x_i} & = \frac{\partial E(y_i, \dots)}{\partial y_i} \frac{\partial y_i}{\partial x_i} + \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \frac{\partial y_{i+C}}{\partial x_i}\\ & = \frac{\partial E(y_i, \dots)}{\partial y_i} \frac{\partial \sin(x_i)}{\partial x_i} + \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \frac{\partial \cos(x_i)}{\partial x_i}\\ & = \frac{\partial E(y_i, \dots)}{\partial y_i} \cos(x_i) - \frac{\partial E(y_{i+C}, \dots)}{\partial y_{i+C}} \sin(x_i) \end{align} $$
The $\frac{\partial E(y_i, \dots)}{\partial y_i}$ is stored in top[n]->[c|g]pu_diff.
angle_to_trig.cpp
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void AngleToTrigLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// Takes an arbitrary number of angles in radians and returns the sin and cos
// of the inputs. The cosines are appended at the end of the sines.
// i.e. [sin_1, sin_2, ..., sin_n, cos_1, cos_2, ...,cos_n]
CHECK_EQ(bottom[0]->height(), 1);
CHECK_EQ(bottom[0]->width(), 1);
// num, channels, height, width,
(*top)[0]->Reshape(bottom[0]->num(), 2 * bottom[0]->channels(), 1, 1);
tmp_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1);
}
template <typename Dtype>
void AngleToTrigLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
int n_channel = bottom[0]->channels();
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
for (int n = 0; n < bottom[0]->num(); ++n) {
// #(angles) = #(channel)
caffe_sin(n_channel, bottom_data + bottom[0]->offset(n),
top_data + (*top)[0]->offset(n));
caffe_cos(n_channel, bottom_data + bottom[0]->offset(n),
top_data + (*top)[0]->offset(n, n_channel));
}
}
template <typename Dtype>
void AngleToTrigLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
const int n_channel = (*bottom)[0]->channels();
const Dtype* top_data = top[0]->cpu_data(); // [sin(x) cos(x)], no need to compute again
const Dtype* top_diff = top[0]->cpu_diff();
if (propagate_down[0]) {
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
for (int n = 0; n < top[0]->num(); ++n) {
caffe_mul(n_channel, top_diff + top[0]->offset(n),
top_data + top[0]->offset(n, n_channel),
bottom_diff + (*bottom)[0]->offset(n));
caffe_mul(n_channel, top_diff + top[0]->offset(n, n_channel),
top_data + top[0]->offset(n),
tmp_.mutable_cpu_data());
caffe_sub(n_channel, bottom_diff + (*bottom)[0]->offset(n),
tmp_.cpu_data(),
bottom_diff + (*bottom)[0]->offset(n));
}
}
}
INSTANTIATE_CLASS(AngleToTrigLayer);
} // namespace caffe
angle_to_trig.cu
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
/* Define automatic type switching cuda sin,cos functions
* [sin,cos]f are single precision trigonometric functions
* [sin,cos] are double precision trigonometric functions
*/
template <typename Dtype>
Dtype auto_type_cos(Dtype x);
template <typename Dtype>
Dtype auto_type_sin(Dtype x);
template<>
__device__ float auto_type_cos<float>(const float x){ return cosf(x);}
template<>
__device__ double auto_type_cos<double>(const double x){ return cos(x);}
template<>
__device__ float auto_type_sin<float>(const float x){ return sinf(x);}
template<>
__device__ double auto_type_sin<double>(const double x){ return sin(x);}
template <typename Dtype>
__global__ void AngleToTrigLayerForward(const int n, const int n_channel,
const Dtype* in, Dtype* out) {
CUDA_KERNEL_LOOP(index, n) {
unsigned int data_index = index / n_channel;
unsigned int channel_index = index % n_channel;
unsigned int top_index = 2 * n_channel * data_index + channel_index;
out[top_index] =
auto_type_sin<Dtype>(in[index]);
out[top_index + n_channel] =
auto_type_cos<Dtype>(in[index]);
}
}
template <typename Dtype>
void AngleToTrigLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
const int n_channel = bottom[0]->channels();
// NOLINT_NEXT_LINE(whitespace/operators)
AngleToTrigLayerForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, n_channel, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
}
template <typename Dtype>
__global__ void AngleToTrigLayerBackward(const int n, const int n_channel,
const Dtype* top_diff, const Dtype* top_data, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, n) {
unsigned int data_index = index / n_channel;
unsigned int channel_index = index % n_channel;
unsigned int top_index = 2 * n_channel * data_index + channel_index;
bottom_diff[index] =
top_diff[top_index] * top_data[top_index + n_channel]
- top_diff[top_index + n_channel] * top_data[top_index];
}
}
template <typename Dtype>
void AngleToTrigLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* top_data = top[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
const int count = (*bottom)[0]->count();
const int n_channel = (*bottom)[0]->channels();
// NOLINT_NEXT_LINE(whitespace/operators)
AngleToTrigLayerBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, n_channel, top_diff, top_data, bottom_diff);
CUDA_POST_KERNEL_CHECK;
}
}
INSTANTIATE_CLASS(AngleToTrigLayer);
} // namespace caffe
Loss Layer
A loss layer does not have to produce any top outputs, since the loss is the final output. However, in Caffe, you can use the top blob's diff to set the scale (loss weight) of a specific loss layer.
The scale is fed into the loss layer using
// Scale gradient
const Dtype loss_weight = top[0]->cpu_diff()[0];
This is common practice and is used in many conventional loss layers including Euclidean Loss, Contrastive Loss, etc.
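For example, here is a sketch modeled on how the Euclidean loss applies the weight in Backward_cpu (diff_ stands for a member blob holding the prediction minus the label; the exact lines differ between Caffe versions):
const Dtype loss_weight = top[0]->cpu_diff()[0];
// Scale the gradient by loss_weight / batch size while writing it to the
// bottom blob: bottom_diff = alpha * diff_ + 0 * bottom_diff.
const Dtype alpha = loss_weight / (*bottom)[0]->num();
caffe_cpu_axpby((*bottom)[0]->count(), alpha, diff_.cpu_data(),
                Dtype(0), (*bottom)[0]->mutable_cpu_diff());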