diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala
index 6f4d3d21..cc5a5f79 100644
--- a/src/main/scala/lantern/ad_lms_vector.scala
+++ b/src/main/scala/lantern/ad_lms_vector.scala
@@ -205,7 +205,6 @@ trait TensorExp extends Dsl with Diff {
   trait BackendNative extends Backend {
     override def dot(x: Tensor, y: Tensor): Tensor = {
       // TODO: (Fei Wang): only support 2D dot 1D and 1D dot 1D
-      assert (x.dims.size <= 2 && y.dims.size == 1, "TODO: (Fei Wang): only support 2D dot 1D and 1D dot 1D")
       val off = var_new(0)
       val up = if (x.rank > 1) x.shape(0) else 1
       val res = NewArray[Float](up)
diff --git a/src/main/scala/lantern/nnModule.scala b/src/main/scala/lantern/nnModule.scala
index a1bf3dfb..28144dee 100644
--- a/src/main/scala/lantern/nnModule.scala
+++ b/src/main/scala/lantern/nnModule.scala
@@ -41,7 +41,7 @@ trait NNModule extends TensorExp {
   }
 
   case class Linear1D(val inSize: Int, val outSize: Int, val name: String = "linear1d") extends Module {
-    val scale: Float = 1.0f / inSize
+    val scale: Float = 1.0f / sqrt(inSize).toFloat
     val weight = regTensorWithName("w")(TensorR(Tensor.rand(scale, outSize, inSize)))
     val bias = regTensorWithName("b")(TensorR(Tensor.zeros(outSize)))
     def apply(in: TensorR): TensorR @diff = weight.dot(in) + bias
@@ -50,7 +50,7 @@ trait NNModule extends TensorExp {
   case class Conv2D(val inChannel: Int, val outChannel: Int, val kernelSize: NSeq[Int], val stride: NSeq[Int] = NSeq(1, 1), val pad: Int = 0, val name: String = "conv2d") extends Module {
     assert(kernelSize.size == 2, "kernel_size should be Seq[Int] of size 2")
     assert(stride.size == 2, "stride should be Seq[Int] of size 2")
-    val scale: Float = 1.0f / (inChannel * kernelSize.head * kernelSize.last)
+    val scale: Float = 1.0f / sqrt(inChannel * kernelSize.head * kernelSize.last).toFloat
     val kernel = regTensorWithName("k")(TensorR(Tensor.rand(scale, outChannel, inChannel, kernelSize.head, kernelSize.last)))
     val bias = regTensorWithName("b")(TensorR(Tensor.zeros(outChannel)))
     def apply(in: TensorR): TensorR @diff = in.convBBP(kernel, bias, stride, NSeq(pad, pad, pad, pad))
@@ -68,7 +68,7 @@ trait NNModule extends TensorExp {
       if (descent)
        tr.x -= tr.d * learning_rate
      else
-       tr.x -= tr.d * learning_rate
+       tr.x += tr.d * learning_rate
      tr.clear_grad()
    }
  }
diff --git a/src/test/scala/lantern/mnistCNN.scala b/src/test/scala/lantern/mnistCNN.scala
index 6fa7acc3..060136ae 100644
--- a/src/test/scala/lantern/mnistCNN.scala
+++ b/src/test/scala/lantern/mnistCNN.scala
@@ -263,7 +263,7 @@ class MnistCNN extends FunSuite {
       }
     }
     val net = MNIST("model")
-    val opt = Adagrad(net, learning_rate = 0.0005f)
+    val opt = SGD(net, learning_rate = 0.0005f, gradClip = 1000.0f)
     def lossFun(input: TensorR, target: Rep[Int]) = { (dummy: TensorR) =>
       val res = net(input).logSoftmax()
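
Note (not part of the patch): the two behavioral changes above, the 1/sqrt(fanIn) initialization scale and the sign of the parameter update, can be illustrated in isolation. The following is a minimal plain-Scala sketch; it does not use Lantern's Tensor/TensorR API, and the names InitScaleSketch, initWeights, and sgdStep are hypothetical, chosen only for illustration. Drawing weights uniformly from [-1/sqrt(fanIn), 1/sqrt(fanIn)] keeps the variance of a neuron's pre-activation roughly independent of fan-in (the old 1/fanIn scale shrinks it as fanIn grows), and the update must subtract the scaled gradient when descending and add it when ascending, which is the branch fixed in nnModule.scala.

    // Illustrative sketch only; not Lantern code.
    import scala.util.Random

    object InitScaleSketch {
      // Uniform init in [-scale, scale] with scale = 1/sqrt(fanIn), as Linear1D/Conv2D now use.
      def initWeights(fanIn: Int, fanOut: Int, rng: Random = new Random(42)): Array[Array[Float]] = {
        val scale = 1.0f / math.sqrt(fanIn).toFloat
        Array.fill(fanOut, fanIn)((rng.nextFloat() * 2 - 1) * scale)
      }

      // Plain SGD step: subtract the gradient to descend, add it to ascend.
      def sgdStep(x: Array[Float], grad: Array[Float], lr: Float, descent: Boolean): Unit = {
        var i = 0
        while (i < x.length) {
          if (descent) x(i) -= lr * grad(i) else x(i) += lr * grad(i)
          i += 1
        }
      }

      def main(args: Array[String]): Unit = {
        val w = initWeights(fanIn = 784, fanOut = 10)
        // Every entry is bounded in magnitude by 1/sqrt(784), roughly 0.0357.
        println(f"max |w| = ${w.flatten.map(v => math.abs(v)).max}%.4f")
      }
    }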