From 1e5f0066f244ca0065df6a509867df30cb880f22 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 2 Oct 2022 23:15:23 -0700 Subject: [PATCH 1/7] Optimize targeted_left_multiply for onehot source --- cirq-core/cirq/linalg/transformations.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index dbe57d53de7..aa486621c3e 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -141,6 +141,21 @@ def targeted_left_multiply( k = len(target_axes) d = len(right_target.shape) + nonzeros = np.flatnonzero(left_matrix) + if len(nonzeros) == 1: + index = np.unravel_index(nonzeros[0], left_matrix.shape) + if out is None: + out = np.zeros_like(right_target) + else: + out[...] = 0 + source_slices = [slice(None)] * d + target_slices = [slice(None)] * d + for i in range(k): + source_slices[target_axes[i]] = slice(index[k + i], index[k + i] + 1) + target_slices[target_axes[i]] = slice(index[i], index[i] + 1) + out[target_slices] = right_target[source_slices] * left_matrix[index] + return out + work_indices = tuple(range(k)) data_indices = tuple(range(k, k + d)) used_data_indices = tuple(data_indices[q] for q in target_axes) From d6f51f375728e2c1d7d8d8dffe48f09e23c4cbe6 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 2 Oct 2022 23:35:46 -0700 Subject: [PATCH 2/7] mypy --- cirq-core/cirq/linalg/transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index aa486621c3e..628c97b68d5 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -153,7 +153,7 @@ def targeted_left_multiply( for i in range(k): source_slices[target_axes[i]] = slice(index[k + i], index[k + i] + 1) target_slices[target_axes[i]] = slice(index[i], index[i] + 1) - out[target_slices] = right_target[source_slices] * left_matrix[index] + out[target_slices] = right_target[source_slices] * left_matrix[index] # type: ignore return out work_indices = tuple(range(k)) From d6e91df4a59e5d8f45fe27c04c6d03696f54062c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 2 Oct 2022 23:46:03 -0700 Subject: [PATCH 3/7] fix test --- cirq-core/cirq/linalg/transformations_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index b92ff5faa13..bbd2ec5ed11 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -192,7 +192,7 @@ def test_targeted_conjugate_simple(): np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16]), + 13, 14, 15, 16], dtype=np.complex), (2,) * 4 ) expected = np.reshape( From f834d31e839aa9a09a38b5629f8eb64458ec1778 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 2 Oct 2022 23:56:28 -0700 Subject: [PATCH 4/7] fix test --- cirq-core/cirq/linalg/transformations.py | 3 ++- cirq-core/cirq/linalg/transformations_test.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 628c97b68d5..5f4d5d9026e 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -153,7 +153,8 @@ def targeted_left_multiply( for i in range(k): source_slices[target_axes[i]] = slice(index[k + i], index[k + i] + 1) target_slices[target_axes[i]] = slice(index[i], index[i] + 1) - out[target_slices] = right_target[source_slices] * left_matrix[index] # type: ignore + proj = right_target[tuple(source_slices)] # type: ignore + out[tuple(target_slices)] = proj * left_matrix[index] # type: ignore return out work_indices = tuple(range(k)) diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index bbd2ec5ed11..5c5f1af1a5c 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -192,7 +192,7 @@ def test_targeted_conjugate_simple(): np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16], dtype=np.complex), + 13, 14, 15, 16], dtype=complex), (2,) * 4 ) expected = np.reshape( From 3a485bbea34c9c4275a21eaf10bbd197e1bf8b4a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 3 Oct 2022 00:07:19 -0700 Subject: [PATCH 5/7] mypy --- cirq-core/cirq/linalg/transformations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 5f4d5d9026e..b83348e286e 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -153,8 +153,7 @@ def targeted_left_multiply( for i in range(k): source_slices[target_axes[i]] = slice(index[k + i], index[k + i] + 1) target_slices[target_axes[i]] = slice(index[i], index[i] + 1) - proj = right_target[tuple(source_slices)] # type: ignore - out[tuple(target_slices)] = proj * left_matrix[index] # type: ignore + out[tuple(target_slices)] = right_target[tuple(source_slices)] * left_matrix[index] return out work_indices = tuple(range(k)) From db507ebffe5ac7ed30aecd42557132c193114515 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 3 Oct 2022 01:34:02 -0700 Subject: [PATCH 6/7] simplify expression --- cirq-core/cirq/linalg/transformations.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index b83348e286e..e0d41851430 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -14,7 +14,7 @@ """Utility methods for transforming matrices or vectors.""" -from typing import Tuple, Optional, Sequence, List, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -148,11 +148,11 @@ def targeted_left_multiply( out = np.zeros_like(right_target) else: out[...] = 0 - source_slices = [slice(None)] * d - target_slices = [slice(None)] * d + source_slices: List[Any] = [slice(None)] * d + target_slices: List[Any] = [slice(None)] * d for i in range(k): - source_slices[target_axes[i]] = slice(index[k + i], index[k + i] + 1) - target_slices[target_axes[i]] = slice(index[i], index[i] + 1) + source_slices[target_axes[i]] = index[k + i] + target_slices[target_axes[i]] = index[i] out[tuple(target_slices)] = right_target[tuple(source_slices)] * left_matrix[index] return out From 9bdd43b183b80b056d54b6d7188eada3fa6c000c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 3 Oct 2022 10:59:26 -0700 Subject: [PATCH 7/7] Optimize for `1` multiple --- cirq-core/cirq/linalg/transformations.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index e0d41851430..a99f9b375b0 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -143,6 +143,8 @@ def targeted_left_multiply( d = len(right_target.shape) nonzeros = np.flatnonzero(left_matrix) if len(nonzeros) == 1: + # This is just moving a slice from one place to another with everything else zeros, and an + # optional rescale. More efficient to operate on slices directly than do a full einsum. index = np.unravel_index(nonzeros[0], left_matrix.shape) if out is None: out = np.zeros_like(right_target) @@ -153,7 +155,10 @@ def targeted_left_multiply( for i in range(k): source_slices[target_axes[i]] = index[k + i] target_slices[target_axes[i]] = index[i] - out[tuple(target_slices)] = right_target[tuple(source_slices)] * left_matrix[index] + sleis = right_target[tuple(source_slices)] + if left_matrix[index] != 1: + sleis *= left_matrix[index] + out[tuple(target_slices)] = sleis return out work_indices = tuple(range(k))