Explicit device tensors (#1081)

kpot authored Dec 20, 2023
1 parent 7c6f017, commit 1fd07fc
Showing 235 changed files with 2,076 additions and 1,616 deletions.
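Every file in this commit applies the same API migration: constructors that took a device through a `_device`-suffixed name (e.g. `Tensor::random_device`) now take the device explicitly under the plain name, and the old device-free constructors are renamed with a `_devauto` suffix that falls back on the backend's default device. A minimal sketch of the pattern, inferred from the call sites in the diffs below (the `example` function, shapes, and values are illustrative, not part of the commit):

use burn::tensor::{backend::Backend, Distribution, Tensor};

fn example<B: Backend>(device: &B::Device) {
    // Post-commit: the plain constructor takes the device explicitly.
    // Pre-commit, this call was spelled `Tensor::random_device`.
    let _explicit = Tensor::<B, 2>::random([2, 3], Distribution::Default, device);

    // Post-commit: device-free construction moves to a `_devauto` suffix,
    // which places the tensor on the backend's default device.
    // Pre-commit, this call was spelled `Tensor::from_data`.
    let _auto = Tensor::<B, 2>::from_data_devauto([[1.0, 2.0], [3.0, 4.0]]);
}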
4 changes: 2 additions & 2 deletions backend-comparison/benches/binary.rs
@@ -23,8 +23,8 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
}

fn prepare(&self) -> Self::Args {
- let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
- let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
+ let lhs = Tensor::random(self.shape.clone(), Distribution::Default, &self.device);
+ let rhs = Tensor::random(self.shape.clone(), Distribution::Default, &self.device);

(lhs, rhs)
}
2 changes: 1 addition & 1 deletion backend-comparison/benches/custom_gelu.rs
@@ -39,7 +39,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
}

fn prepare(&self) -> Self::Args {
- Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device)
+ Tensor::random(self.shape.clone(), Distribution::Default, &self.device)
}

fn sync(&self) {
4 changes: 2 additions & 2 deletions backend-comparison/benches/data.rs
@@ -24,7 +24,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
}

fn prepare(&self) -> Self::Args {
- Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device)
+ Tensor::random(self.shape.clone(), Distribution::Default, &self.device)
}

fn sync(&self) {
@@ -48,7 +48,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {

fn execute(&self, (data, device): Self::Args) {
for _ in 0..self.num_repeats {
- let _data = Tensor::<B, D>::from_data_device(data.clone(), &device);
+ let _data = Tensor::<B, D>::from_data(data.clone(), &device);
}
}

6 changes: 2 additions & 4 deletions backend-comparison/benches/matmul.rs
@@ -32,10 +32,8 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}

fn prepare(&self) -> Self::Args {
- let lhs =
-     Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, &self.device);
- let rhs =
-     Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, &self.device);
+ let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
+ let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);

(lhs, rhs)
}
2 changes: 1 addition & 1 deletion backend-comparison/benches/unary.rs
@@ -25,7 +25,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
}

fn prepare(&self) -> Self::Args {
- Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device)
+ Tensor::random(self.shape.clone(), Distribution::Default, &self.device)
}

fn sync(&self) {
4 changes: 2 additions & 2 deletions burn-autodiff/src/tests/abs.rs
@@ -8,8 +8,8 @@ mod tests {
let data_1 = Data::<f32, 2>::from([[0.0, -1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, -10.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let tensor_4 = tensor_3.matmul(tensor_2.clone());
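The autodiff tests are device-agnostic, so their construction sites move to the `_devauto` variants instead of gaining a device argument. A short sketch of the post-commit test pattern, reusing the values from the abs test above (`TestAutodiffTensor` is the alias defined by the test harness; the gradient check is abbreviated):

// Device-free construction goes through `_devauto`, which uses the
// backend's default device; gradient flow is unchanged by the rename.
let tensor_1 = TestAutodiffTensor::from_data_devauto([[0.0, -1.0], [3.0, 4.0]]).require_grad();
let tensor_2 = TestAutodiffTensor::from_data_devauto([[6.0, 7.0], [9.0, -10.0]]).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());
let grads = tensor_3.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();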
6 changes: 3 additions & 3 deletions burn-autodiff/src/tests/adaptive_avgpool1d.rs
@@ -13,7 +13,7 @@ mod tests {
output_size: 3,
};

- test.assert_output(TestTensor::from_floats([[
+ test.assert_output(TestTensor::from_floats_devauto([[
[0.5000, 0.8333, 0.3333, 0.8333, 0.5000],
[0.5000, 0.8333, 0.3333, 0.8333, 0.5000],
]]));
@@ -29,8 +29,8 @@
impl AdaptiveAvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
- let x = TestAutodiffTensor::from_data(
-     TestTensorInt::arange(0..shape_x.num_elements())
+ let x = TestAutodiffTensor::from_data_devauto(
+     TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
6 changes: 3 additions & 3 deletions burn-autodiff/src/tests/adaptive_avgpool2d.rs
@@ -15,7 +15,7 @@ mod tests {
output_size_2: 2,
};

- test.assert_output(TestTensor::from_floats([[
+ test.assert_output(TestTensor::from_floats_devauto([[
[
[0.2500, 0.5000, 0.2500],
[0.4167, 0.8333, 0.4167],
@@ -45,8 +45,8 @@
impl AdaptiveAvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
- let x = TestAutodiffTensor::from_data(
-     TestTensorInt::arange(0..shape_x.num_elements())
+ let x = TestAutodiffTensor::from_data_devauto(
+     TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
12 changes: 6 additions & 6 deletions burn-autodiff/src/tests/add.rs
@@ -5,8 +5,8 @@ mod tests {

#[test]
fn should_diff_add() {
- let tensor_1 = TestAutodiffTensor::from_floats([2.0, 5.0]).require_grad();
- let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0]).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_floats_devauto([2.0, 5.0]).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_floats_devauto([4.0, 1.0]).require_grad();

let tensor_3 = tensor_1.clone() + tensor_2.clone();
let grads = tensor_3.backward();
@@ -23,7 +23,7 @@
fn should_diff_add_scalar() {
let data = Data::from([2.0, 10.0]);

- let tensor = TestAutodiffTensor::from_data(data).require_grad();
+ let tensor = TestAutodiffTensor::from_data_devauto(data).require_grad();
let tensor_out = tensor.clone().add_scalar(5.0);
let grads = tensor_out.backward();

@@ -39,9 +39,9 @@
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
let data_3: Data<f32, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
- let tensor_3 = TestAutodiffTensor::from_data(data_3).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();
+ let tensor_3 = TestAutodiffTensor::from_data_devauto(data_3).require_grad();

let tensor_4 = tensor_1.clone().add(tensor_2.clone());
let tensor_5 = tensor_4
20 changes: 10 additions & 10 deletions burn-autodiff/src/tests/aggregation.rs
@@ -8,8 +8,8 @@ mod tests {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());
@@ -31,8 +31,8 @@
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());
@@ -54,8 +54,8 @@
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.clone().sum_dim(1);
@@ -78,8 +78,8 @@
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());
@@ -101,8 +101,8 @@
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [-2.0, -3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, -7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());
10 changes: 5 additions & 5 deletions burn-autodiff/src/tests/avgpool1d.rs
@@ -16,7 +16,7 @@ mod tests {
count_include_pad: true,
};

- test.assert_output(TestTensor::from_floats([[[
+ test.assert_output(TestTensor::from_floats_devauto([[[
0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333,
]]]));
}
@@ -33,7 +33,7 @@
count_include_pad: true,
};

- test.assert_output(TestTensor::from_floats([[
+ test.assert_output(TestTensor::from_floats_devauto([[
[0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333],
[0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.3333],
]]));
@@ -51,7 +51,7 @@
count_include_pad: false,
};

- test.assert_output(TestTensor::from_floats([[
+ test.assert_output(TestTensor::from_floats_devauto([[
[0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333],
[0.5000, 0.8333, 0.3333, 0.6667, 0.3333, 0.3333],
]]));
@@ -70,8 +70,8 @@
impl AvgPool1dTestCase {
fn assert_output(self, x_grad: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
- let x = TestAutodiffTensor::from_data(
-     TestTensorInt::arange(0..shape_x.num_elements())
+ let x = TestAutodiffTensor::from_data_devauto(
+     TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
10 changes: 5 additions & 5 deletions burn-autodiff/src/tests/avgpool2d.rs
@@ -20,7 +20,7 @@ mod tests {
count_include_pad: true,
};

- test.assert_output(TestTensor::from_floats([[[
+ test.assert_output(TestTensor::from_floats_devauto([[[
[0.1111, 0.2222, 0.3333, 0.3333, 0.2222, 0.1111],
[0.2222, 0.4444, 0.6667, 0.6667, 0.4444, 0.2222],
[0.3333, 0.6667, 1.0000, 1.0000, 0.6667, 0.3333],
@@ -46,7 +46,7 @@
count_include_pad: true,
};

- test.assert_output(TestTensor::from_floats([[[
+ test.assert_output(TestTensor::from_floats_devauto([[[
[0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
@@ -70,7 +70,7 @@
count_include_pad: false,
};

- test.assert_output(TestTensor::from_floats([[[
+ test.assert_output(TestTensor::from_floats_devauto([[[
[0.6250, 0.6250, 0.4167, 0.4167, 0.6250, 0.6250],
[0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750],
[0.8750, 0.8750, 0.5833, 0.5833, 0.8750, 0.8750],
@@ -95,8 +95,8 @@
impl AvgPool2dTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
- let x = TestAutodiffTensor::from_data(
-     TestTensorInt::arange(0..shape_x.num_elements())
+ let x = TestAutodiffTensor::from_data_devauto(
+     TestTensorInt::arange_devauto(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
6 changes: 3 additions & 3 deletions burn-autodiff/src/tests/backward.rs
@@ -11,9 +11,9 @@ mod tests {
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
]);
- let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights).require_grad();
- let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices);
- let x = Tensor::<TestAutodiffBackend, 3>::from_data(x).require_grad();
+ let weights = Tensor::<TestAutodiffBackend, 2>::from_data_devauto(weights).require_grad();
+ let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data_devauto(indices);
+ let x = Tensor::<TestAutodiffBackend, 3>::from_data_devauto(x).require_grad();

let output = embedding(weights.clone(), indices);
let output = output.matmul(x);
4 changes: 2 additions & 2 deletions burn-autodiff/src/tests/broadcast.rs
@@ -37,8 +37,8 @@ mod tests {
where
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
{
- let w = TestAutodiffTensor::zeros([16, 5, 5]).require_grad();
- let x = TestAutodiffTensor::zeros([4, 5, 5]).require_grad();
+ let w = TestAutodiffTensor::zeros_devauto([16, 5, 5]).require_grad();
+ let x = TestAutodiffTensor::zeros_devauto([4, 5, 5]).require_grad();

// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
// of an operation that support broadcast doesn't support it during the backward pass.
13 changes: 8 additions & 5 deletions burn-autodiff/src/tests/cat.rs
@@ -5,8 +5,10 @@ mod tests {

#[test]
fn should_diff_cat() {
- let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad();
+ let tensor_1 =
+     TestAutodiffTensor::from_data_devauto([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+ let tensor_2 =
+     TestAutodiffTensor::from_data_devauto([[5.0, 4.0], [-1.0, 4.0]]).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let grads = tensor_3.backward();
@@ -57,9 +59,10 @@

#[test]
fn should_diff_cat_more_than_1_dim() {
- let tensor_1 = TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
- let tensor_2 =
-     TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad();
+ let tensor_1 =
+     TestAutodiffTensor::from_data_devauto([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]])
+     .require_grad();

// Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.
// The resulting tensor should be [5, 2]
12 changes: 6 additions & 6 deletions burn-autodiff/src/tests/complex.rs
@@ -8,8 +8,8 @@ mod tests {
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
@@ -35,8 +35,8 @@
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());
@@ -59,8 +59,8 @@
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

- let tensor_1 = TestAutodiffTensor::from_data(data_1).require_grad();
- let tensor_2 = TestAutodiffTensor::from_data(data_2).require_grad();
+ let tensor_1 = TestAutodiffTensor::from_data_devauto(data_1).require_grad();
+ let tensor_2 = TestAutodiffTensor::from_data_devauto(data_2).require_grad();

let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = tensor_3.matmul(tensor_1.clone());