Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autograph with make plxpr #6645

Merged
merged 105 commits into from
Dec 5, 2024
Merged
Changes from 1 commit
Commits
Show all changes
105 commits
Select commit Hold shift + click to select a range
5dfc38f
add user decompose function
andrijapau Oct 3, 2024
f1d9eee
code factor
andrijapau Oct 3, 2024
b24c97b
update changelog
andrijapau Oct 3, 2024
0d3bfb6
improve decomposition function
andrijapau Oct 4, 2024
0d66209
codefactor fix
andrijapau Oct 4, 2024
a25d6ba
added modified generator
andrijapau Oct 4, 2024
2ad96ac
copy code from catalyst
lillian542 Oct 4, 2024
c471c0b
remove catalyst from __init__ and transformer files
lillian542 Oct 4, 2024
49c0fac
replace many catalyst dependencies/mentions
lillian542 Oct 7, 2024
5dd1d42
copy tests from catalyst
lillian542 Oct 7, 2024
c94da87
update imports
lillian542 Oct 7, 2024
ab8ae7a
[ci skip]
lillian542 Oct 8, 2024
ac7d53e
Merge branch 'master' into add_autograph
lillian542 Oct 8, 2024
4a6bdf3
import autograph module in capture
lillian542 Oct 8, 2024
2c2f1dd
remove final catalyst dependencies from ag_primitives
lillian542 Oct 8, 2024
9b18b97
record queueing in FlatFn so we can access program_length in ag_primi…
lillian542 Oct 8, 2024
cf92f2a
most conditional tests work
lillian542 Oct 9, 2024
7a21764
initial removal of queueing dependency
lillian542 Oct 10, 2024
72b3e42
the current state of the tests
lillian542 Oct 10, 2024
fee16ff
remove index setting and logical operators
lillian542 Oct 11, 2024
3ccd0e0
formatting
lillian542 Oct 11, 2024
1ab03f3
remove fallback
lillian542 Oct 12, 2024
2bfaf72
remove disable_autograph
lillian542 Oct 12, 2024
42ce1e8
remove autograph_include
lillian542 Oct 12, 2024
d3a807c
Merge branch 'master' into autograph_ctrl_flow
lillian542 Oct 12, 2024
fa0b415
remove strict_conversion and ignore_fallbacks
lillian542 Oct 12, 2024
54b3036
Tidying up
lillian542 Oct 12, 2024
26cbde4
Tidying up more
lillian542 Oct 12, 2024
211bd81
fix more tests
lillian542 Oct 16, 2024
6be1ff5
Merge branch 'autograph_ctrl_flow' of github.com:PennyLaneAI/pennylan…
lillian542 Oct 16, 2024
75b0c3c
only cond in ag_primitives
lillian542 Oct 16, 2024
5bfecb4
clean up transformer and utils
lillian542 Oct 16, 2024
ac51fad
re-organize tests
lillian542 Oct 16, 2024
d4d7220
remove utils and clean up docstrings
lillian542 Oct 16, 2024
207ecc6
update changelog
lillian542 Oct 16, 2024
d2f6e34
add malt as dependency of PL
lillian542 Oct 16, 2024
da0cd1d
fix failing test
lillian542 Oct 16, 2024
cea27f6
Merge branch 'master' into autograph1
lillian542 Oct 16, 2024
9da4afa
package name is diastatic-malt
lillian542 Oct 16, 2024
fbff208
Merge branch 'autograph1' of github.com:PennyLaneAI/pennylane into au…
lillian542 Oct 16, 2024
e443173
fix decorator test
lillian542 Oct 16, 2024
2ee18ca
rename test file to avoid CI confusion
lillian542 Oct 16, 2024
3ab7c9c
add while_loop implementation
lillian542 Oct 16, 2024
900cfe1
add test file
lillian542 Oct 17, 2024
fc59744
a few more tests and docstrings updates
lillian542 Oct 17, 2024
a8946af
one more test for code coverage
lillian542 Oct 17, 2024
908d6b1
Merge branch 'master' into autograph1
lillian542 Oct 17, 2024
60cb38b
Update pennylane/capture/autograph/ag_primitives.py
lillian542 Oct 17, 2024
9c8c74d
add initial tests
lillian542 Oct 17, 2024
bfc39b2
Merge branch 'autograph1' into autograph_while_loop
lillian542 Oct 17, 2024
3972b89
some small test changes
lillian542 Oct 17, 2024
ef2d503
update changelog
lillian542 Oct 21, 2024
19130e9
xfail test that includes for loop
lillian542 Oct 21, 2024
4b73551
add for loop support and tests
lillian542 Oct 21, 2024
d5eaa8d
Merge branch 'master' into autograph1
lillian542 Oct 28, 2024
7da5662
Merge branch 'autograph1' into autograph_while_loop
lillian542 Oct 28, 2024
a03acf8
Merge branch 'master' into autograph1
lillian542 Nov 21, 2024
1a01dea
Apply suggestions from code review
lillian542 Nov 21, 2024
42c91f9
use inner_args to avoid taken arguments
lillian542 Nov 21, 2024
78307b5
update copyright year
lillian542 Nov 21, 2024
e36fa65
use functools.wraps
lillian542 Nov 21, 2024
998ce1a
replace qjit example with pl example
lillian542 Nov 21, 2024
83a934e
add import path for run_autograph, autograph_source
lillian542 Nov 21, 2024
36267d8
change import structure + update example
lillian542 Nov 21, 2024
c06e768
fix a couple docstring mistakes
lillian542 Nov 21, 2024
d6fae5d
Merge branch 'autograph1' into autograph_while_loop
lillian542 Nov 21, 2024
a63009d
Apply suggestions from code review
lillian542 Nov 21, 2024
793e7f4
remove unneeded check
lillian542 Nov 26, 2024
1f1f097
small test fixes
lillian542 Nov 26, 2024
705eb35
Merge branch 'autograph_while_loop' of github.com:PennyLaneAI/pennyla…
lillian542 Nov 26, 2024
c387988
Merge branch 'master' into autograph1
lillian542 Nov 26, 2024
84ca81b
Merge branch 'autograph1' into autograph_while_loop
lillian542 Nov 26, 2024
e8f23ab
Merge branch 'autograph_while_loop' into autograph_for_loop
lillian542 Nov 26, 2024
ed9ea11
remove source_info function
lillian542 Nov 26, 2024
0faba0e
reoraganize and update tests
lillian542 Nov 26, 2024
4c5d969
clean up error msgs and docstrings
lillian542 Nov 26, 2024
b981ade
update tests
lillian542 Nov 26, 2024
6432807
update changelog
lillian542 Nov 26, 2024
3d6d40e
pylint complaint
lillian542 Nov 26, 2024
9403d0d
add autograph to make_plxpr
lillian542 Nov 27, 2024
4173a31
update tests
lillian542 Nov 27, 2024
947a8ab
add autograph argument to make_plxpr
lillian542 Nov 27, 2024
558901b
remove old code from previous implementation
lillian542 Dec 3, 2024
c439ed0
Merge branch 'master' into autograph_with_make_plxpr
lillian542 Dec 3, 2024
f4d8b07
update docstring
lillian542 Dec 3, 2024
5ab68c7
add malt to doc requirements
lillian542 Dec 3, 2024
be51555
make usage details collapsable
lillian542 Dec 4, 2024
ee22c0c
update changelog
lillian542 Dec 4, 2024
a73c584
does this make a collapsable tab?
lillian542 Dec 4, 2024
2682f18
Merge branch 'master' into autograph_with_make_plxpr
lillian542 Dec 4, 2024
0f57ce4
this is my least favourite game to play with sphinx
lillian542 Dec 4, 2024
c3b3c5e
Update doc/releases/changelog-dev.md
lillian542 Dec 4, 2024
4e2a206
try to add basic details section
lillian542 Dec 4, 2024
aa78fba
Merge branch 'autograph_with_make_plxpr' of github.com:PennyLaneAI/pe…
lillian542 Dec 4, 2024
7c8520d
add some of the content back in
lillian542 Dec 4, 2024
8f9189a
changelog update
lillian542 Dec 4, 2024
b6f2695
fix merge conflicts
lillian542 Dec 4, 2024
3bdd4a6
fingers crossed it all renders now
lillian542 Dec 4, 2024
f4c7a9c
add space between example cells
lillian542 Dec 4, 2024
e30dbdb
Apply suggestions from code review
lillian542 Dec 4, 2024
42a2656
re-organize changelog
lillian542 Dec 4, 2024
3234898
Merge branch 'master' into autograph_with_make_plxpr
lillian542 Dec 4, 2024
f845987
don't reorganize changelog
lillian542 Dec 5, 2024
3649d94
don't reorganize changelog pt2
lillian542 Dec 5, 2024
945a5bd
Merge branch 'master' into autograph_with_make_plxpr
lillian542 Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
formatting
lillian542 committed Oct 11, 2024
commit 3ccd0e0a328deeb943346ccdcb7379836730a5a9
2 changes: 2 additions & 0 deletions pennylane/capture/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
@@ -395,6 +395,7 @@ def while_stmt(loop_test, loop_body, get_state, set_state, symbol_names, _opts):

set_state(results)


def get_source_code_info(tb_frame):
"""Attempt to obtain original source code information for an exception raised within AutoGraph
transformed code.
@@ -505,6 +506,7 @@ def qnode_call_wrapper():

return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)


class PRange:
"""PennyLane range object.

3 changes: 2 additions & 1 deletion pennylane/capture/autograph/transformer.py
Original file line number Diff line number Diff line change
@@ -24,10 +24,11 @@
import inspect
from contextlib import ContextDecorator

import pennylane as qml
from malt.core import ag_ctx, converter
from malt.impl.api import PyToPy

import pennylane as qml

from . import ag_primitives
from .utils import AutoGraphError

75 changes: 46 additions & 29 deletions tests/capture/test_autograph.py
Original file line number Diff line number Diff line change
@@ -21,10 +21,10 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.core import eval_jaxpr

# from catalyst import debug, qjit, vmap
from jax.errors import TracerBoolConversionError
from jax.core import eval_jaxpr
from numpy.testing import assert_allclose

import pennylane as qml
@@ -38,7 +38,6 @@
)
from pennylane.capture.autograph.utils import AutoGraphError, CompileError, dummy_func


check_cache = TRANSFORMER.has_cache

# pylint: disable=import-outside-toplevel
@@ -567,15 +566,19 @@ def circuit(n):
return res

# can't convert to jaxpr without autorgraph
with pytest.raises(jax.errors.TracerBoolConversionError, match="Attempted boolean conversion of traced array"):
with pytest.raises(
jax.errors.TracerBoolConversionError,
match="Attempted boolean conversion of traced array",
):
jax.make_jaxpr(circuit)(1)

# with autograph we can convert to jaxpr
circuit = run_autograph(circuit)
jaxpr = jax.make_jaxpr(circuit)(0)
assert "cond" in str(jaxpr)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)

# evaluating the jaxpr gives expected results
assert res(0) == [0]
@@ -603,7 +606,8 @@ def circuit(x):
jaxpr = jax.make_jaxpr(ag_circuit)(0)
assert "cond" in str(jaxpr)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

assert res(4) == 16
assert res(2) == 4
@@ -628,7 +632,8 @@ def circuit(x):
jaxpr = jax.make_jaxpr(ag_circuit)(0)
assert "cond" in str(jaxpr)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

assert res(5) == 40
assert res(3) == 12
@@ -651,7 +656,8 @@ def circuit(x):
jaxpr = jax.make_jaxpr(ag_circuit)(0)
assert "cond" in str(jaxpr)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

# pylint: disable=singleton-comparison
assert res(3) == 0
@@ -707,7 +713,8 @@ def f(x: int):
jaxpr = jax.make_jaxpr(ag_circuit)(0)
assert "cond" in str(jaxpr)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

assert res(1) == 25
assert res(0) == 60
@@ -731,7 +738,8 @@ def f(x: float):
ag_circuit = run_autograph(f)
jaxpr = jax.make_jaxpr(ag_circuit)(0)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

# assert capfd.readouterr() == ("", "")

@@ -759,7 +767,9 @@ def f(switch: bool):
ag_circuit = run_autograph(f)
jaxpr = jax.make_jaxpr(ag_circuit)(0)

def res(x): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]
def res(x):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)[0]

res(1)


@@ -789,8 +799,10 @@ def f(params):
return qml.expval(qml.PauliZ(0))

ag_circuit = run_autograph(f)
jaxpr = jax.make_jaxpr(ag_circuit)([1., 2., 3.])
def res(params): return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, params)
jaxpr = jax.make_jaxpr(ag_circuit)([1.0, 2.0, 3.0])

def res(params):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, params)

result = f(jnp.array([0.0, 1 / 4 * jnp.pi, 2 / 4 * jnp.pi]))
print(result)
@@ -1069,8 +1081,10 @@ def f(params):
ag_circuit = run_autograph(f)
jaxpr = jax.make_jaxpr(ag_circuit)(jnp.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]))

params = jnp.array([[0.0, 1 / 4 * jnp.pi], [2 / 4 * jnp.pi, 3 / 4 * jnp.pi], [jnp.pi, 2 * jnp.pi]])
result =eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, params)
params = jnp.array(
[[0.0, 1 / 4 * jnp.pi], [2 / 4 * jnp.pi, 3 / 4 * jnp.pi], [jnp.pi, 2 * jnp.pi]]
)
result = eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, params)

assert np.allclose(result, [jnp.sqrt(2) / 2, -jnp.sqrt(2) / 2, -1.0])

@@ -1430,7 +1444,7 @@ def f(p):
return qml.probs()

ag_circuit = run_autograph(f)
jaxpr = jax.make_jaxpr(ag_circuit)(0.)
jaxpr = jax.make_jaxpr(ag_circuit)(0.0)
assert "while_loop" in str(jaxpr)

result = eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2.0**4)[0]
@@ -1716,7 +1730,7 @@ def loop(i, acc):

assert f() == 3

# @pytest.mark.usefixtures("autograph_strict_conversion")
# @pytest.mark.usefixtures("autograph_strict_conversion")
def test_cond_if_for_loop_for(self, monkeypatch):
"""Test Python conditionals and loops together with their Catalyst counterparts."""

@@ -1754,6 +1768,8 @@ def even():

assert f(2) == 18
assert f(3) == 0


#
#
#
@@ -1801,6 +1817,7 @@ def even():
# assert g() == 36.4
#


class TestAutographInclude:
"""Test include modules to autograph conversion"""

@@ -1848,19 +1865,19 @@ def fn(x: float, n: int):
class TestDecorators:
"""Test if Autograph works when applied to a decorated function"""

# def test_vmap(self):
# """Test if Autograph works when applied to a decorated function with vmap"""
#
# def workflow(axes_dct):
# return axes_dct["x"] + axes_dct["y"]
#
# expected = jnp.array([1, 2, 3, 4, 5])
#
# result = qjit(vmap(workflow, in_axes=({"x": None, "y": 0},)), autograph=True)(
# {"x": 1, "y": jnp.arange(5)}
# )
# assert jnp.allclose(result, expected)
#
# def test_vmap(self):
# """Test if Autograph works when applied to a decorated function with vmap"""
#
# def workflow(axes_dct):
# return axes_dct["x"] + axes_dct["y"]
#
# expected = jnp.array([1, 2, 3, 4, 5])
#
# result = qjit(vmap(workflow, in_axes=({"x": None, "y": 0},)), autograph=True)(
# {"x": 1, "y": jnp.arange(5)}
# )
# assert jnp.allclose(result, expected)
#
def test_cond(self):
"""Test if Autograph works when applied to a decorated function with cond"""