From 91777b22901d9f4335a92307a7b223e5eb19d0a1 Mon Sep 17 00:00:00 2001 From: liuyuang Date: Wed, 10 Jan 2024 15:52:18 +0800 Subject: [PATCH] update ut --- test/auto_parallel/test_shard_tensor_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py index 14b0529b96bd5e..5efa0c25031f18 100644 --- a/test/auto_parallel/test_shard_tensor_api.py +++ b/test/auto_parallel/test_shard_tensor_api.py @@ -93,9 +93,9 @@ def test_dynamic_mode_property_change(self): self.assertEqual(d_tensor.process_mesh, self.mesh) def test_stop_gradient(self): - x = paddle.ones([10, 10]) + x = paddle.ones([4, 1024, 512]) x.stop_gradient = False - x = dist.shard_tensor(x, self.mesh, [Shard(0)]) + x = dist.shard_tensor(x, self.mesh, [Shard(0), Replicate()]) assert not x.stop_gradient