Skip to content

Commit

Permalink
Add extra library test decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
manoelmarques committed Apr 16, 2021
1 parent ed4770e commit 5b41d99
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 240 deletions.
8 changes: 6 additions & 2 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

""" ML test packages """

from .machine_learning_test_case import QiskitMachineLearningTestCase
from .machine_learning_test_case import (QiskitMachineLearningTestCase,
requires_extra_library)

__all__ = ['QiskitMachineLearningTestCase']
__all__ = [
'QiskitMachineLearningTestCase',
'requires_extra_library'
]
117 changes: 56 additions & 61 deletions test/algorithms/distribution_learners/qgan/test_qgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import unittest
import warnings
import tempfile
from test import QiskitMachineLearningTestCase
from test import QiskitMachineLearningTestCase, requires_extra_library

from qiskit import BasicAer
from qiskit.circuit.library import UniformDistribution, RealAmplitudes
from qiskit.utils import algorithm_globals, QuantumInstance
from qiskit.exceptions import MissingOptionalLibraryError
from qiskit.algorithms.optimizers import CG, COBYLA
from qiskit.opflow.gradients import Gradient
from qiskit_machine_learning.algorithms import (NumPyDiscriminator,
Expand Down Expand Up @@ -116,72 +115,68 @@ def test_qgan_training(self):
trained_qasm = self.qgan.run(self.qi_qasm)
self.assertAlmostEqual(trained_qasm['rel_entr'], trained_statevector['rel_entr'], delta=0.1)

@requires_extra_library
def test_qgan_training_run_algo_torch(self):
"""Test QGAN training using a PyTorch discriminator."""
try:
# Set number of qubits per data dimension as list of k qubit values[#q_0,...,#q_k-1]
num_qubits = [2]
# Batch size
batch_size = 100
# Set number of training epochs
num_epochs = 5
_qgan = QGAN(self._real_data,
self._bounds,
num_qubits,
batch_size,
num_epochs,
discriminator=PyTorchDiscriminator(n_features=len(num_qubits)),
snapshot_dir=None)
_qgan.seed = self.seed
_qgan.set_generator()
trained_statevector = _qgan.run(QuantumInstance(
BasicAer.get_backend('statevector_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
trained_qasm = _qgan.run(QuantumInstance(BasicAer.get_backend('qasm_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
self.assertAlmostEqual(trained_qasm['rel_entr'],
trained_statevector['rel_entr'], delta=0.1)
except MissingOptionalLibraryError:
self.skipTest('pytorch not installed, skipping test')
# Set number of qubits per data dimension as list of k qubit values[#q_0,...,#q_k-1]
num_qubits = [2]
# Batch size
batch_size = 100
# Set number of training epochs
num_epochs = 5
_qgan = QGAN(self._real_data,
self._bounds,
num_qubits,
batch_size,
num_epochs,
discriminator=PyTorchDiscriminator(n_features=len(num_qubits)),
snapshot_dir=None)
_qgan.seed = self.seed
_qgan.set_generator()
trained_statevector = _qgan.run(QuantumInstance(
BasicAer.get_backend('statevector_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
trained_qasm = _qgan.run(QuantumInstance(BasicAer.get_backend('qasm_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
self.assertAlmostEqual(trained_qasm['rel_entr'],
trained_statevector['rel_entr'], delta=0.1)

@requires_extra_library
def test_qgan_training_run_algo_torch_multivariate(self):
"""Test QGAN training using a PyTorch discriminator, for multivariate distributions."""
try:
# Set number of qubits per data dimension as list of k qubit values[#q_0,...,#q_k-1]
num_qubits = [1, 2]
# Batch size
batch_size = 100
# Set number of training epochs
num_epochs = 5
# Set number of qubits per data dimension as list of k qubit values[#q_0,...,#q_k-1]
num_qubits = [1, 2]
# Batch size
batch_size = 100
# Set number of training epochs
num_epochs = 5

# Reshape data in a multi-variate fashion
# (two independent identically distributed variables,
# each represented by half of the generated samples)
real_data = self._real_data.reshape((-1, 2))
bounds = [self._bounds, self._bounds]
# Reshape data in a multi-variate fashion
# (two independent identically distributed variables,
# each represented by half of the generated samples)
real_data = self._real_data.reshape((-1, 2))
bounds = [self._bounds, self._bounds]

_qgan = QGAN(real_data,
bounds,
num_qubits,
batch_size,
num_epochs,
discriminator=PyTorchDiscriminator(n_features=len(num_qubits)),
snapshot_dir=None)
_qgan.seed = self.seed
_qgan.set_generator()
trained_statevector = _qgan.run(QuantumInstance(
BasicAer.get_backend('statevector_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
trained_qasm = _qgan.run(QuantumInstance(BasicAer.get_backend('qasm_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
self.assertAlmostEqual(trained_qasm['rel_entr'],
trained_statevector['rel_entr'], delta=0.1)
except MissingOptionalLibraryError:
self.skipTest('pytorch not installed, skipping test')
_qgan = QGAN(real_data,
bounds,
num_qubits,
batch_size,
num_epochs,
discriminator=PyTorchDiscriminator(n_features=len(num_qubits)),
snapshot_dir=None)
_qgan.seed = self.seed
_qgan.set_generator()
trained_statevector = _qgan.run(QuantumInstance(
BasicAer.get_backend('statevector_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
trained_qasm = _qgan.run(QuantumInstance(BasicAer.get_backend('qasm_simulator'),
seed_simulator=algorithm_globals.random_seed,
seed_transpiler=algorithm_globals.random_seed))
self.assertAlmostEqual(trained_qasm['rel_entr'],
trained_statevector['rel_entr'], delta=0.1)

def test_qgan_training_run_algo_numpy(self):
"""Test QGAN training using a NumPy discriminator."""
Expand Down
Loading

0 comments on commit 5b41d99

Please sign in to comment.