
Commit

Preliminary work to start implementing the Flatten layer (#6), Conv2D (#5), and Dropout (#14).
cesarsouza committed Nov 23, 2017
1 parent 51c220c commit 4be2d2f
Showing 10 changed files with 724 additions and 43 deletions.
48 changes: 42 additions & 6 deletions Backends/CNTK.CPU/CNTKBackend.cs
@@ -387,8 +387,11 @@ public Tensor add(Tensor a, Tensor b)
return Out(new Variable(In(a).function) + new Variable(In(b).function));
}

public Tensor bias_add(Tensor output, Tensor bias, string name = null)
public Tensor bias_add(Tensor output, Tensor bias, DataFormatType? data_format = null, string name = null)
{
if (data_format != null)
throw new NotImplementedException();

using (this.name_scope("bias_add"))
{
CNTKTensor _x = In(output);
@@ -474,6 +477,21 @@ public Tensor transpose(Tensor tensor)
return Out(C.Transpose(In(tensor)));
}

/// <summary>
/// Turns an nD tensor into a 2D tensor with the same 0th dimension. In other words, it flattens each data sample of the batch.
/// </summary>
///
public Tensor batch_flatten(Tensor x)
{
// https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/backend/cntk_backend.py#L1460
// CNTK's batch axis is not part of the shape,
// so just flatten all the dimensions in x.shape
int dim = Matrix.Product(x.shape.Select(s => s.Value).ToArray());
x = Out(C.Reshape(In(x), NDShape.CreateNDShape(new[] { -1 })));
x._keras_shape = new int?[] { null, dim };
return x;
}
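
An illustrative usage sketch (editorial, not part of this commit; assumes an IBackend instance `K` whose placeholder follows the signature shown further below):

Tensor x = K.placeholder(shape: new int?[] { null, 28, 28, 3 });
Tensor flat = K.batch_flatten(x);
// flat._keras_shape is { null, 2352 }: the batch axis is preserved,
// and the remaining 28 * 28 * 3 = 2352 values collapse into one dimension.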

public object eval(Tensor tensor)
{
log(new { tensor });
@@ -504,7 +522,7 @@ public Tensor clip(Tensor norms, double minval, double maxval)
throw new NotImplementedException();
}

public Tensor random_uniform(int?[] shape, double minval = 0, double maxval = 1, DataType? dtype = null, int? seed = null, string name = null)
public Tensor random_uniform(int[] shape, double minval = 0, double maxval = 1, DataType? dtype = null, int? seed = null, string name = null)
{
if (dtype == null)
dtype = floatx();
@@ -694,7 +712,7 @@ public Tensor in_train_phase(Func<Tensor> x, Func<Tensor> alt, bool? training)
return Out(In(input_tensor).function.Output.DataType);
}

public Tensor constant<T>(T value, int?[] shape = null, KerasSharp.DataType? dtype = null, string name = null)
public Tensor constant<T>(T value, int[] shape = null, KerasSharp.DataType? dtype = null, string name = null)
{
log(new { value, shape, dtype, name });

@@ -705,7 +723,7 @@ public Tensor constant<T>(T value, int?[] shape = null, KerasSharp.DataType? dty
return Out(_const, shape);
}

public Constant InGeneric<T>(T value, int?[] shape = null, KerasSharp.DataType? dtype = null, string name = null)
public Constant InGeneric<T>(T value, int[] shape = null, KerasSharp.DataType? dtype = null, string name = null)
{
if (dtype == null)
dtype = floatx();
@@ -726,7 +744,7 @@ public Constant InGeneric<T>(T value, int?[] shape = null, KerasSharp.DataType?
}
else
{
_shape = shape.Select(x => x.Value).ToArray();
_shape = shape;
}

Constant c = _constant(value, _shape, _dtype, name);
@@ -979,6 +997,20 @@ public Tensor reshape(Tensor x, int[] shape)
}


public Tensor conv1d(Tensor inputs, Tensor kernel, int strides, PaddingType padding, DataFormatType? data_format = null, int dilation_rate = 1, string name = null)
{
throw new NotImplementedException();
}

public Tensor conv2d(Tensor inputs, Tensor kernel, int[] strides, PaddingType padding, DataFormatType? data_format = null, int[] dilation_rate = null, string name = null)
{
throw new NotImplementedException();
}

public Tensor conv3d(Tensor inputs, Tensor kernel, int[] strides, PaddingType padding, DataFormatType? data_format = null, int[] dilation_rate = null, string name = null)
{
throw new NotImplementedException();
}



@@ -1084,6 +1116,11 @@ public NDShape InShape(int[] shape)
return s;
}

public Tensor Out(CNTK.Function function, int[] keras_shape)
{
return Out(function, keras_shape.Select(x => (int?)x).ToArray());
}

public Tensor Out(CNTK.Function function, int?[] keras_shape = null)
{
var t = new CNTKTensor(this)
@@ -1207,7 +1244,6 @@ public void Dispose()
// TODO: uncomment the following line if the finalizer is overridden above.
// GC.SuppressFinalize(this);
}

#endregion
}
}
179 changes: 157 additions & 22 deletions Backends/TensorFlow/TensorFlowBackend.cs
@@ -52,7 +52,7 @@ public class TensorFlowBackend : BackendBase, IBackend
// This dictionary holds a mapping {graph: learning_phase}.
// A learning phase is a bool tensor used to run Keras models in
// either train mode (learning_phase == 1) or test mode (learning_phase == 0).
private Dictionary<TFGraph, TFOutput> _GRAPH_LEARNING_PHASES = new Dictionary<TFGraph, TFOutput>();
private Dictionary<TFGraph, object> _GRAPH_LEARNING_PHASES = new Dictionary<TFGraph, object>();

// This dictionary holds a mapping {graph: UID_DICT}.
// each UID_DICT is a dictionary mapping name prefixes to a current index,
@@ -115,7 +115,7 @@ public void clear_session()
//
reset_uids();
TFOutput phase = tf.Placeholder(dtype: TFDataType.Bool, operName: "keras_learning_phase");
_GRAPH_LEARNING_PHASES = new Dictionary<TFGraph, TFOutput>();
_GRAPH_LEARNING_PHASES = new Dictionary<TFGraph, object>();
_GRAPH_LEARNING_PHASES[tf] = phase;
}

@@ -254,7 +254,7 @@ public Tensor clip_norm(Tensor g, double clipnorm, Tensor norm)
throw new NotImplementedException();
}

public Tensor constant<T>(T value, int?[] shape = null, DataType? dtype = null, string name = null)
public Tensor constant<T>(T value, int[] shape = null, DataType? dtype = null, string name = null)
{
if (dtype == null)
dtype = floatx();
@@ -266,11 +266,11 @@ public Tensor constant<T>(T value, int?[] shape = null, DataType? dtype = null,
if (arr != null)
_shape = arr.GetLength();
else _shape = new int[0];
shape = _shape.Select(x => (int?)x).ToArray();
shape = _shape;
}
else
{
_shape = shape.Select(x => x.Value).ToArray();
_shape = shape;
}

TFOutput o;
@@ -444,7 +444,9 @@ public Tensor in_train_phase(Func<Tensor> x, Func<Tensor> alt, bool? training)

if (training == null)
{
training = (bool)learning_phase();
var t = learning_phase();
if (t is bool)
training = (bool)t;
uses_learning_phase = true;
}
else
@@ -463,14 +465,21 @@ public Tensor in_train_phase(Func<Tensor> x, Func<Tensor> alt, bool? training)
else
{
//else: assume learning phase is a placeholder tensor.
throw new NotImplementedException();
}

// Tensor xx = @switch(training, x, alt);
Tensor xx = @switch((Tensor)learning_phase(), x, alt);

if (uses_learning_phase)
x()._uses_learning_phase = true;
return x();
if (uses_learning_phase)
xx._uses_learning_phase = true;
return xx;
}
}

/// <summary>
/// Selects `x` in test phase, and `alt` otherwise. Note that `alt` should have the *same shape* as `x`.
/// </summary>
public Tensor in_test_phase(Func<Tensor> x, Func<Tensor> alt, bool? training = null)
{
return in_train_phase(alt, x, training: training);
}
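
A usage sketch (editorial, not from this commit; `K` is an assumed backend instance and `dropout` a hypothetical helper producing a dropped-out copy of its input):

Tensor inputs = K.placeholder(shape: new int?[] { null, 128 });
Tensor dropped = dropout(inputs, rate: 0.5);  // hypothetical dropout op
Tensor result = K.in_train_phase(() => dropped, () => inputs, training: null);
// With training == null, the backend consults learning_phase(); when that is
// still a placeholder tensor, @switch below picks the branch at run time.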

/// <summary>
@@ -491,12 +500,10 @@ public Tensor @switch(Tensor condition, Func<Tensor> then_expression, Func<Tenso
if (_condition.dtype != TFDataType.Bool)
condition = Out(tf.Cast(_condition, TFDataType.Bool));

throw new NotImplementedException();

//TFOutput x = tf.cond(condition,
// () => then_expression().output,
// () => else_expression().output);
//return tensor(x);
TFOutput x = tf.Cond(In(condition),
() => In(then_expression()),
() => In(else_expression()));
return Out(x);
}
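
A minimal sketch of the now-implemented tf.Cond path (editorial; assumes tensors `a` and `b` of the same shape built elsewhere):

Tensor selected = K.@switch((Tensor)K.learning_phase(),
    () => a,   // subgraph returned when the condition is true
    () => b);  // subgraph returned when the condition is false
// tf.Cond builds both branch subgraphs but executes only the branch
// selected by the boolean condition at run time.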

public bool is_sparse(Tensor tensor)
@@ -530,7 +537,15 @@ public object learning_phase()
_GRAPH_LEARNING_PHASES[graph] = phase;
}

return Out(_GRAPH_LEARNING_PHASES[graph]);
return _GRAPH_LEARNING_PHASES[graph];
}

/// <summary>
/// Sets the learning phase to a fixed value.
/// </summary>
public void set_learning_phase(bool value)
{
_GRAPH_LEARNING_PHASES[tf] = value;
}
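
Usage sketch (editorial, not part of this commit; `K` is an assumed backend instance):

K.set_learning_phase(true);          // pin the phase to a fixed bool
object phase = K.learning_phase();   // now returns true instead of a placeholder
// Without set_learning_phase, learning_phase() returns the graph's
// "keras_learning_phase" bool placeholder, which must be fed at run time.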

public Tensor max(Tensor x, int v, object p)
@@ -553,6 +568,17 @@ public Tensor maximum(double v, Tensor tensor)
throw new NotImplementedException();
}

/// <summary>
/// Turns an nD tensor into a 2D tensor with the same 0th dimension. In other words, it flattens each data sample of the batch.
/// </summary>
///
public Tensor batch_flatten(Tensor x)
{
var _x = In(x);
TFOutput shape = tf.Shape(_x);
TFOutput dim = tf.Prod(tf.Slice(shape, tf.Const(1), tf.Rank(shape)), reduction_indices: tf.ReduceDims(shape, null));
return Out(tf.Reshape(In(x), tf.Stack(new TFOutput[] { tf.Const(-1), dim } )));
}
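
The graph ops above amount to this static-shape arithmetic (editorial sketch; FlattenedDim is a hypothetical helper, not part of the commit):

static int FlattenedDim(int[] shape)
{
    int dim = 1;
    for (int i = 1; i < shape.Length; i++)  // skip batch axis 0
        dim *= shape[i];
    return dim;
}
// FlattenedDim(new[] { 32, 28, 28, 3 }) == 2352, i.e. (32, 28, 28, 3) -> (32, 2352).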


public TFOutput _normalize_axis(int[] axis, int? ndim)
@@ -664,9 +690,25 @@ public Tensor add(Tensor a, Tensor b)
return Out(tf.Add(In(a).output, In(b).output));
}

public Tensor bias_add(Tensor a, Tensor b, string name = null)
public Tensor bias_add(Tensor a, Tensor b, DataFormatType? data_format = null, string name = null)
{
return Out(tf.BiasAdd(In(a), In(b), data_format: In(data_format), operName: name));
}

private string In(DataFormatType? data_format)
{
return Out(tf.BiasAdd(In(a), In(b), operName: name));
if (data_format == null)
return null;

switch (data_format.Value)
{
case DataFormatType.ChannelsFirst:
return "channels_first";
case DataFormatType.ChannelsLast:
return "channels_last";
default:
throw new ArgumentException($"Invalid data_format: {data_format}");
}
}
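
Usage sketch for the new overload (editorial; `output` and `bias` are assumed to be built elsewhere):

Tensor y = K.bias_add(output, bias, data_format: DataFormatType.ChannelsLast);
// The private In(DataFormatType?) above turns the enum into the string
// "channels_last" that is forwarded to tf.BiasAdd; null leaves it unset.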

public Tensor add<T>(T a, Tensor b)
@@ -766,7 +808,7 @@ public Tensor placeholder(int?[] shape = null, int? ndim = null, DataType? dtype
///
/// <returns>A tensor.</returns>
///
public Tensor random_uniform(int?[] shape, double minval = 0.0, double maxval = 1.0, DataType? dtype = null, int? seed = null, string name = null)
public Tensor random_uniform(int[] shape, double minval = 0.0, double maxval = 1.0, DataType? dtype = null, int? seed = null, string name = null)
{
if (dtype == null)
dtype = floatx();
@@ -989,6 +1031,11 @@ public Tensor transpose(Tensor tensor)
return Out(tf.Transpose(In(tensor).output));
}

public Tensor transpose(Tensor tensor, int[] perm)
{
return Out(tf.Transpose(In(tensor).output, _constant(perm)));
}


public object eval(Tensor tensor)
{
@@ -1021,6 +1068,94 @@ public object eval(TFOutput output)



public Tensor conv1d(Tensor inputs, Tensor kernel, int strides, PaddingType padding, DataFormatType? data_format = null, int dilation_rate = 1, string name = null)
{
throw new NotImplementedException();
}

public Tensor conv2d(Tensor inputs, Tensor kernel, int[] strides, PaddingType padding, DataFormatType? data_format = null, int[] dilation_rate = null, string name = null)
{
// https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/backend/tensorflow_backend.py#L3102
if (data_format == null)
data_format = image_data_format();

if (!dilation_rate.IsEqual(new[] { 1, 1 }))
throw new NotImplementedException();

TFOutput x = In(inputs).output;
TFOutput _kernel = In(kernel).output;

// With 4d inputs, tf.nn.convolution only supports
// data_format NHWC, so we transpose the inputs
// in case we are in data_format channels_first.
x = _preprocess_conv2d_input(x, data_format.Value);
string _padding = _preprocess_padding(padding);
x = tf.Conv2D(
input: x,
filter: _kernel,
//dilation_rate: dilation_rate,
strides: strides.Select(i => (long)i).ToArray(),
padding: _padding,
data_format: "NHWC");
return Out(_postprocess_conv2d_output(x, data_format.Value));
}
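
Usage sketch (editorial, not part of this commit; assumes a backend instance `K`):

Tensor inputs = K.placeholder(shape: new int?[] { null, 3, 32, 32 });  // NCHW batch
Tensor kernel = K.random_uniform(new[] { 3, 3, 3, 16 });               // H, W, in, out
Tensor output = K.conv2d(inputs, kernel,
    strides: new[] { 1, 1 },
    padding: PaddingType.Same,
    data_format: DataFormatType.ChannelsFirst,
    dilation_rate: new[] { 1, 1 });
// The channels-first input is transposed to NHWC, convolved via tf.Conv2D,
// and transposed back by _postprocess_conv2d_output.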

/// <summary>
/// Transpose and cast the output from conv2d if needed.
/// </summary>
private TFOutput _postprocess_conv2d_output(TFOutput x, DataFormatType data_format)
{
if (data_format == DataFormatType.ChannelsFirst)
x = tf.Transpose(x, _constant(new[] { 0, 3, 1, 2 }));

if (floatx() == DataType.Double)
x = tf.Cast(x, TFDataType.Double);
return x;
}

/// <summary>
/// Convert keras' padding to tensorflow's padding.
/// </summary>
///
public string _preprocess_padding(PaddingType padding)
{
switch (padding)
{
case PaddingType.Same:
return "SAME";
case PaddingType.Valid:
return "VALID";
}

throw new ArgumentException($"Invalid padding: {padding}");
}

/// <summary>
/// Transpose and cast the input before the conv2d.
/// </summary>
private TFOutput _preprocess_conv2d_input(TFOutput x, DataFormatType data_format)
{
if (x.OutputType == TFDataType.Double)
x = tf.Cast(x, TFDataType.Float);

if (data_format == DataFormatType.ChannelsFirst)
{
// TF uses the last dimension as channel dimension,
// instead of the 2nd one.
// TH input shape: (samples, input_depth, rows, cols)
// TF input shape: (samples, rows, cols, input_depth)
x = tf.Transpose(x, _constant(new[] { 0, 2, 3, 1 }));
}

return x;
}

public Tensor conv3d(Tensor inputs, Tensor kernel, int[] strides, PaddingType padding, DataFormatType? data_format = null, int[] dilation_rate = null, string name = null)
{
throw new NotImplementedException();
}



/// <summary>
/// Instantiates an all-zeros variable and returns it.
