Skip to content

Commit

Permalink
Regularization "Upper bound on L2 norm of the incoming weight vector …
Browse files Browse the repository at this point in the history
…for each output neuron" added
  • Loading branch information
milakov committed Jul 27, 2013
1 parent 6e62b97 commit abf9638
Show file tree
Hide file tree
Showing 29 changed files with 1,022 additions and 20 deletions.
175 changes: 175 additions & 0 deletions nnforge/cuda/convolution_weight_vector_bound_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright 2011-2013 Maxim Milakov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "convolution_weight_vector_bound_cuda.h"

#include "../convolution_layer.h"

extern __shared__ float arr[];
template<bool single_item_per_thread>
__global__ void convolution_normalize_weights_to_max_l2_norm_kernel(
float * __restrict weights,
const float * __restrict weights_read_copy,
float max_l2_norm_squared,
int incoming_weight_count_per_output_neuron,
int output_feature_map_count,
int min_iteration_count)
{
int thread_id = threadIdx.x;
int output_feature_map_id = blockIdx.y;
int entry_id = blockIdx.z;
int threadblock_size = blockDim.x;

int base_weight_id = (entry_id * output_feature_map_count + output_feature_map_id) * incoming_weight_count_per_output_neuron;
float current_val;
float sum = 0.0F;
int current_weight_id = thread_id;
for(int i = 0; i < min_iteration_count; ++i)
{
current_val = weights_read_copy[base_weight_id + current_weight_id];
sum += current_val * current_val;
current_weight_id += threadblock_size;
}
if (current_weight_id < incoming_weight_count_per_output_neuron)
{
current_val = weights_read_copy[base_weight_id + current_weight_id];
sum += current_val * current_val;
}
arr[thread_id] = sum;
__syncthreads();

int t_add_elems = threadblock_size >> 1;
int t_working_elems = (threadblock_size + 1) >> 1;
while (t_add_elems > 0)
{
if (thread_id < t_add_elems)
arr[thread_id] += arr[thread_id + t_working_elems];
t_add_elems = t_working_elems >> 1;
t_working_elems = (t_working_elems + 1) >> 1;
__syncthreads();
}

sum = arr[0];
if (sum <= max_l2_norm_squared)
return;

float mult = rsqrtf(__fdividef(sum, max_l2_norm_squared));

if (single_item_per_thread)
{
if (thread_id < incoming_weight_count_per_output_neuron)
weights[base_weight_id + thread_id] = current_val * mult;
}
else
{
int current_weight_id = thread_id;
for(int i = 0; i < min_iteration_count; ++i)
{
weights[base_weight_id + current_weight_id] += weights_read_copy[base_weight_id + current_weight_id] * mult;
current_weight_id += threadblock_size;
}
if (current_weight_id < incoming_weight_count_per_output_neuron)
weights[base_weight_id + current_weight_id] += weights_read_copy[base_weight_id + current_weight_id] * mult;
}
}

namespace nnforge
{
namespace cuda
{
convolution_weight_vector_bound_cuda::convolution_weight_vector_bound_cuda()
{
}

convolution_weight_vector_bound_cuda::~convolution_weight_vector_bound_cuda()
{
}

const boost::uuids::uuid& convolution_weight_vector_bound_cuda::get_uuid() const
{
return convolution_layer::layer_guid;
}

void convolution_weight_vector_bound_cuda::enqueue_normalize_weights(
cudaStream_t stream_id,
const weight_vector_bound& bound,
const std::vector<cuda_linear_buffer_device_smart_ptr>& data,
const std::vector<cuda_linear_buffer_device_smart_ptr>& additional_buffers,
unsigned int entry_count)
{
int threadblock_size = get_threadblock_size(incoming_weight_count_per_output_neuron);
dim3 grid_size(1, output_feature_map_count, entry_count);
dim3 block_size(threadblock_size, 1, 1);
int min_iteration_count = incoming_weight_count_per_output_neuron / threadblock_size;
int smem_size = threadblock_size * sizeof(float);
float max_l2_norm_squared = bound.max_l2_norm * bound.max_l2_norm;

if (incoming_weight_count_per_output_neuron <= threadblock_size)
{
convolution_normalize_weights_to_max_l2_norm_kernel<true><<<grid_size, block_size, smem_size, stream_id>>>(
*data[0],
*data[0],
max_l2_norm_squared,
incoming_weight_count_per_output_neuron,
output_feature_map_count,
min_iteration_count);
}
else
{
convolution_normalize_weights_to_max_l2_norm_kernel<false><<<grid_size, block_size, smem_size, stream_id>>>(
*data[0],
*data[0],
max_l2_norm_squared,
incoming_weight_count_per_output_neuron,
output_feature_map_count,
min_iteration_count);
}
}

std::tr1::shared_ptr<weight_vector_bound_cuda> convolution_weight_vector_bound_cuda::create_specific() const
{
return std::tr1::shared_ptr<weight_vector_bound_cuda>(new convolution_weight_vector_bound_cuda());
}

void convolution_weight_vector_bound_cuda::weight_vector_bound_configured()
{
std::tr1::shared_ptr<const convolution_layer> layer_derived = std::tr1::dynamic_pointer_cast<const convolution_layer>(layer_schema);

incoming_weight_count_per_output_neuron = layer_derived->input_feature_map_count;
for(std::vector<unsigned int>::const_iterator it = layer_derived->window_sizes.begin(); it != layer_derived->window_sizes.end(); ++it)
incoming_weight_count_per_output_neuron *= *it;
output_feature_map_count = layer_derived->output_feature_map_count;
}

int convolution_weight_vector_bound_cuda::get_threadblock_size(int incoming_weight_count_per_output_neuron)
{
int threadblock_size;

if (incoming_weight_count_per_output_neuron < 256)
{
threadblock_size = (incoming_weight_count_per_output_neuron + 32 - 1) / 32 * 32;
}
else
{
int threadblock_count = (incoming_weight_count_per_output_neuron + 256 - 1) / 256;
threadblock_size = (incoming_weight_count_per_output_neuron + threadblock_count - 1) / threadblock_count;
threadblock_size = (threadblock_size + 32 - 1) / 32 * 32;
}

return threadblock_size;
}
}
}
54 changes: 54 additions & 0 deletions nnforge/cuda/convolution_weight_vector_bound_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2011-2013 Maxim Milakov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "weight_vector_bound_cuda.h"

namespace nnforge
{
namespace cuda
{
class convolution_weight_vector_bound_cuda: public weight_vector_bound_cuda
{
public:
convolution_weight_vector_bound_cuda();

virtual ~convolution_weight_vector_bound_cuda();

virtual const boost::uuids::uuid& get_uuid() const;

virtual void enqueue_normalize_weights(
cudaStream_t stream_id,
const weight_vector_bound& bound,
const std::vector<cuda_linear_buffer_device_smart_ptr>& data,
const std::vector<cuda_linear_buffer_device_smart_ptr>& additional_buffers,
unsigned int entry_count);

protected:
virtual std::tr1::shared_ptr<weight_vector_bound_cuda> create_specific() const;

// The method is called when configuration is finished
virtual void weight_vector_bound_configured();

int incoming_weight_count_per_output_neuron;
int output_feature_map_count;

private:
static int get_threadblock_size(int output_neuron_count);
};
}
}
5 changes: 5 additions & 0 deletions nnforge/cuda/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
#include "soft_rectified_linear_layer_updater_schema.h"
#include "softmax_layer_updater_schema.h"

#include "weight_vector_bound_cuda_factory.h"
#include "convolution_weight_vector_bound_cuda.h"

namespace nnforge
{
namespace cuda
Expand Down Expand Up @@ -83,6 +86,8 @@ namespace nnforge
single_layer_updater_schema_factory::get_mutable_instance().register_layer_updater_schema(layer_updater_schema_smart_ptr(new rectified_linear_layer_updater_schema()));
single_layer_updater_schema_factory::get_mutable_instance().register_layer_updater_schema(layer_updater_schema_smart_ptr(new soft_rectified_linear_layer_updater_schema()));
single_layer_updater_schema_factory::get_mutable_instance().register_layer_updater_schema(layer_updater_schema_smart_ptr(new softmax_layer_updater_schema()));

single_weight_vector_bound_factory::get_mutable_instance().register_weight_vector_bound(weight_vector_bound_cuda_smart_ptr(new convolution_weight_vector_bound_cuda()));
}
}
}
36 changes: 35 additions & 1 deletion nnforge/cuda/network_updater_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "util_cuda.h"
#include "cuda_event.h"
#include "layer_updater_schema_factory.h"
#include "weight_vector_bound_cuda_factory.h"

#include <cuda_runtime.h>
#include <boost/format.hpp>
Expand Down Expand Up @@ -76,8 +77,9 @@ namespace nnforge
network_updater_cuda::network_updater_cuda(
network_schema_smart_ptr schema,
const std::map<unsigned int, float>& layer_to_dropout_rate_map,
const std::map<unsigned int, weight_vector_bound>& layer_to_weight_vector_bound_map,
cuda_running_configuration_const_smart_ptr cuda_config)
: network_updater(schema, layer_to_dropout_rate_map)
: network_updater(schema, layer_to_dropout_rate_map, layer_to_weight_vector_bound_map)
, cuda_config(cuda_config)
{
const const_layer_list& layer_list = *schema;
Expand All @@ -100,6 +102,15 @@ namespace nnforge
for(const_layer_list::const_iterator it = start_layer_nonempty_weights_iterator; it != layer_list.end(); ++it)
updater_schemas.push_back(single_layer_updater_schema_factory::get_const_instance().create_updater_schema_layer(*it, cuda_config));

for(std::map<unsigned int, weight_vector_bound>::const_iterator it = this->layer_to_weight_vector_bound_map.begin(); it != this->layer_to_weight_vector_bound_map.end(); ++it)
{
unsigned int layer_id = it->first;
if (layer_id < testing_layer_count)
throw neural_network_exception((boost::format("Weight vector bound is specified fo layer %1% while it is in testing part (consisting of %2% layers) of the updater") % layer_id % testing_layer_count).str());

weight_vector_bounds.insert(std::make_pair<unsigned int, weight_vector_bound_cuda_smart_ptr>(layer_id, single_weight_vector_bound_factory::get_const_instance().create_weight_vector_bound(layer_list[layer_id], cuda_config)));
}

setup_network_cuda();

for(const_layer_testing_schema_list::const_iterator it = testing_schemas.begin(); it != testing_schemas.end(); ++it)
Expand Down Expand Up @@ -229,6 +240,10 @@ namespace nnforge
output_errors = all_buffers.input_errors_buffer;
}

std::map<unsigned int, std::vector<cuda_linear_buffer_device_smart_ptr> > weight_vector_bound_buffers;
for(std::map<unsigned int, weight_vector_bound_cuda_smart_ptr>::const_iterator it = weight_vector_bounds.begin(); it != weight_vector_bounds.end(); ++it)
weight_vector_bound_buffers.insert(std::make_pair(it->first, it->second->allocate_additional_buffers(max_entry_count)));

cuda_linear_buffer_host_smart_ptr input_host_buf(new cuda_linear_buffer_host(input_neuron_count * max_entry_count * input_neuron_elem_size));
unsigned char * input = *input_host_buf;
cuda_linear_buffer_host_smart_ptr output_host_buf(new cuda_linear_buffer_host(output_neuron_count * max_entry_count * sizeof(float)));
Expand Down Expand Up @@ -411,6 +426,19 @@ namespace nnforge
input_and_all_buffers_pack_it->second.additional_buffers,
input_and_all_buffers_pack_it->second.dynamic_memobjects,
updater_entry_count);

weight_vector_bound_map::iterator bound_it = weight_vector_bounds.find(reverse_layer_id);
if (bound_it != weight_vector_bounds.end())
{
const weight_vector_bound& bound = layer_to_weight_vector_bound_map.find(reverse_layer_id)->second;
const std::vector<cuda_linear_buffer_device_smart_ptr>& additional_buffers = weight_vector_bound_buffers.find(reverse_layer_id)->second;
bound_it->second->enqueue_normalize_weights(
*command_stream,
bound,
*net_data_it,
additional_buffers,
updater_entry_count);
}
}
}

Expand Down Expand Up @@ -628,6 +656,9 @@ namespace nnforge

for(std::vector<layer_updater_cuda_smart_ptr>::const_iterator it = updater_list.begin(); it != updater_list.end(); ++it)
(*it)->update_buffer_configuration(buffer_configuration, updater_entry_count);

for(std::map<unsigned int, weight_vector_bound_cuda_smart_ptr>::const_iterator it = weight_vector_bounds.begin(); it != weight_vector_bounds.end(); ++it)
it->second->update_buffer_configuration(buffer_configuration, updater_entry_count);
}

unsigned int network_updater_cuda::get_max_batch_size() const
Expand All @@ -637,6 +668,9 @@ namespace nnforge
for(std::vector<layer_updater_cuda_smart_ptr>::const_iterator it = updater_list.begin(); it != updater_list.end(); ++it)
(*it)->update_buffer_configuration(buffer_configuration);

for(std::map<unsigned int, weight_vector_bound_cuda_smart_ptr>::const_iterator it = weight_vector_bounds.begin(); it != weight_vector_bounds.end(); ++it)
it->second->update_buffer_configuration(buffer_configuration);

return cuda_config->get_max_entry_count(buffer_configuration, 0.5F);
}
}
Expand Down
4 changes: 4 additions & 0 deletions nnforge/cuda/network_updater_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "cuda_stream.h"
#include "layer_testing_schema.h"
#include "layer_updater_schema.h"
#include "weight_vector_bound_cuda.h"

namespace nnforge
{
Expand All @@ -33,6 +34,7 @@ namespace nnforge
network_updater_cuda(
network_schema_smart_ptr schema,
const std::map<unsigned int, float>& layer_to_dropout_rate_map,
const std::map<unsigned int, weight_vector_bound>& layer_to_weight_vector_bound_map,
cuda_running_configuration_const_smart_ptr cuda_config);

virtual ~network_updater_cuda();
Expand Down Expand Up @@ -89,6 +91,8 @@ namespace nnforge
std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> > updater_schema_data;
std::vector<layer_updater_cuda_smart_ptr> updater_list;

weight_vector_bound_map weight_vector_bounds;

static unsigned int max_entry_count_in_single_batch;
};
}
Expand Down
9 changes: 7 additions & 2 deletions nnforge/cuda/network_updater_cuda_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ namespace nnforge

network_updater_smart_ptr network_updater_cuda_factory::create(
network_schema_smart_ptr schema,
const std::map<unsigned int, float>& layer_to_dropout_rate_map) const
const std::map<unsigned int, float>& layer_to_dropout_rate_map,
const std::map<unsigned int, weight_vector_bound>& layer_to_weight_vector_bound_map) const
{
return network_updater_smart_ptr(new network_updater_cuda(schema, layer_to_dropout_rate_map, cuda_config));
return network_updater_smart_ptr(new network_updater_cuda(
schema,
layer_to_dropout_rate_map,
layer_to_weight_vector_bound_map,
cuda_config));
}
}
}
3 changes: 2 additions & 1 deletion nnforge/cuda/network_updater_cuda_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace nnforge

virtual network_updater_smart_ptr create(
network_schema_smart_ptr schema,
const std::map<unsigned int, float>& layer_to_dropout_rate_map) const;
const std::map<unsigned int, float>& layer_to_dropout_rate_map,
const std::map<unsigned int, weight_vector_bound>& layer_to_weight_vector_bound_map) const;

protected:
cuda_running_configuration_const_smart_ptr cuda_config;
Expand Down
Loading

0 comments on commit abf9638

Please sign in to comment.