Skip to content

Commit

Permalink
feat: Add plot3 functionality including example
Browse files Browse the repository at this point in the history
Signed-off-by: Sietze van Buuren <[email protected]>
  • Loading branch information
swvanbuuren committed Nov 9, 2024
1 parent d65e962 commit 4fa0510
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 11 deletions.
69 changes: 69 additions & 0 deletions examples/lorenz_attractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
""" Plot3 example with Lorenz attractor """

import numpy as np
import mlpyqtgraph as mpg


def lorenz(x, *, s=10, r=28, b=2.667):
"""
Parameters
----------
x : array-like, shape (3,)
Point of interest in three-dimensional space.
s, r, b : float
Parameters defining the Lorenz attractor.
Returns
-------
x_dot : array, shape (3,)
Values of the Lorenz attractor's partial derivatives at x.
"""
return np.array([
s*(x[1] - x[0]),
r*x[0] - x[1] - x[0]*x[2],
x[0]*x[1] - b*x[2]
])


def euler(dxdt, x0, dt=0.005, num_steps=10_000):
"""
Euler integration
Parameters
----------
dxdt : callable
Function that takes a single argument `x` with shape `(n,)` and
returns an array with the same shape, representing the derivative
of `x` with respect to time.
x0 : array-like, shape `(n,)`
Initial condition.
dt : float, optional
Time step. Defaults to 0.005.
num_steps : int, optional
Number of steps to run the integration. Defaults to 10_000.
Returns
-------
x : array, shape `(n, num_steps + 1)`
Path taken by the system during the integration.
"""
x = np.empty((num_steps + 1, len(x0)))
x[0] = x0
for i in range(num_steps):
x[i + 1] = x[i] + dt*dxdt(x[i])
return x.T


@mpg.plotter(antialiasing=True)
def main():
""" Plot Lorenz attractor """
x, y, z = euler(dxdt=lorenz, x0=(0., 1., 1.05))

mpg.figure(title='Lorenz attractor', layout_type='Qt')
mpg.plot3(x, y, z, projection='orthographic')
ax = mpg.gca()
ax.azimuth = 315


if __name__ == '__main__':
main()
38 changes: 29 additions & 9 deletions mlpyqtgraph/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ class Axis3D(gl.GLGraphicsItem.GLGraphicsItem):
'color': (0, 0, 0, 1),
'antialias': True,
'width': 1,
}

grid_line_options = {
**default_line_options,
'glOptions': glOption_lines,
}

Expand Down Expand Up @@ -339,7 +343,7 @@ def set_projection_method(self, *coords, method='orthographic'):
distance = 0.75*object_size/math.tan(0.5*field_of_view/180.0*math.pi)
self.view().setCameraParams(fov=field_of_view, distance=distance)

def add(self, *args, **kwargs):
def surf(self, *args, **kwargs):
""" Adds a 3D surface plot item to the view widget """
kwargs = dict(self.default_surface_options, **kwargs)
surface = gl.GLSurfacePlotItem(*args, **kwargs)
Expand All @@ -351,32 +355,48 @@ def add(self, *args, **kwargs):

def calculate_ax_coord_lims(self, x, y, z):
""" Calculates the axis coordinates limits """
coords = dict(coord_generator(num_ticks=6, x=x, y=y, z=z))
coords = dict(coord_generator(x=x, y=y, z=z))
limits = dict(limit_generator(limit_ratio=0.05, **coords))
return coords, limits

def update_grid_axes(self, *args, **kwargs):
""" Plots the grid axes """
coords, limits = self.calculate_ax_coord_lims(*args)
self.grid_axes.setData(coords=coords, limits=limits)
projection_method = kwargs.get('projection', 'perspective')
self.view().setCameraPosition(**self.grid_axes.best_camera(method=projection_method))
projection = kwargs.get('projection', 'perspective')
print(projection)
self.view().setCameraPosition(**self.grid_axes.best_camera(method=projection))

def add_grid_lines(self, *args):
""" Plots all grid lines """
x, y, z = args[:3]
rows, columns = z.shape
for row in range(rows):
self.add_single_grid_line(x[row]*np.ones(columns), y, z[row])
self.add_line(
x[row]*np.ones(columns), y, z[row],
**self.grid_line_options
)
for col in range(columns):
self.add_single_grid_line(x, y[col]*np.ones(rows), z[:, col])
self.add_line(
x, y[col]*np.ones(rows), z[:, col],
**self.grid_line_options
)

def add_single_grid_line(self, x, y, z):
def add_line(self, *args, **kwargs):
""" Plots a single grid line for given coordinates """
points = np.column_stack((x, y, z))
line = gl.GLLinePlotItem(pos=points, **self.default_line_options)
points = np.column_stack(args)
line = gl.GLLinePlotItem(pos=points, **kwargs)
self.view().addItem(line)

def line(self, *args, **kwargs):
""" Plots a single grid line for given coordinates """
kwargs = dict(self.default_line_options, **kwargs)
lines_kwargs = dict(kwargs)
lines_kwargs.pop('projection')
self.add_line(*args, **lines_kwargs)
self.set_projection_method(*args, method=kwargs['projection'])
self.update_grid_axes(*args, **kwargs)

def delete(self):
""" Closes the axis """

Expand Down
8 changes: 7 additions & 1 deletion mlpyqtgraph/ml_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,10 @@ def surf(*args, **kwargs):
""" Plots a 3D surface """
gcf().change_layout('Qt')
gcf().create_axis(axis_type='3D')
gca().add(*args, **kwargs)
gca().surf(*args, **kwargs)

def plot3(*args, **kwargs):
""" Plots a 3D line """
gcf().change_layout('Qt')
gcf().create_axis(axis_type='3D')
gca().line(*args, **kwargs)
3 changes: 2 additions & 1 deletion mlpyqtgraph/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class AxisWorker(containers.WorkerItem):
factory = containers.WorkerItem.get_factory()
row = factory.attribute()
column = factory.attribute()
add = factory.method()
surf = factory.method()
line = factory.method()
add_legend = factory.method()
grid = factory.attribute()
xlim = factory.attribute()
Expand Down

0 comments on commit 4fa0510

Please sign in to comment.