diff --git a/examples/lorenz_attractor.py b/examples/lorenz_attractor.py new file mode 100644 index 0000000..9410e78 --- /dev/null +++ b/examples/lorenz_attractor.py @@ -0,0 +1,69 @@ +""" Plot3 example with Lorenz attractor """ + +import numpy as np +import mlpyqtgraph as mpg + + +def lorenz(x, *, s=10, r=28, b=2.667): + """ + Parameters + ---------- + x : array-like, shape (3,) + Point of interest in three-dimensional space. + s, r, b : float + Parameters defining the Lorenz attractor. + + Returns + ------- + x_dot : array, shape (3,) + Values of the Lorenz attractor's partial derivatives at x. + """ + return np.array([ + s*(x[1] - x[0]), + r*x[0] - x[1] - x[0]*x[2], + x[0]*x[1] - b*x[2] + ]) + + +def euler(dxdt, x0, dt=0.005, num_steps=10_000): + """ + Euler integration + + Parameters + ---------- + dxdt : callable + Function that takes a single argument `x` with shape `(n,)` and + returns an array with the same shape, representing the derivative + of `x` with respect to time. + x0 : array-like, shape `(n,)` + Initial condition. + dt : float, optional + Time step. Defaults to 0.005. + num_steps : int, optional + Number of steps to run the integration. Defaults to 10_000. + + Returns + ------- + x : array, shape `(n, num_steps + 1)` + Path taken by the system during the integration. + """ + x = np.empty((num_steps + 1, len(x0))) + x[0] = x0 + for i in range(num_steps): + x[i + 1] = x[i] + dt*dxdt(x[i]) + return x.T + + +@mpg.plotter(antialiasing=True) +def main(): + """ Plot Lorenz attractor """ + x, y, z = euler(dxdt=lorenz, x0=(0., 1., 1.05)) + + mpg.figure(title='Lorenz attractor', layout_type='Qt') + mpg.plot3(x, y, z, projection='orthographic') + ax = mpg.gca() + ax.azimuth = 315 + + +if __name__ == '__main__': + main() diff --git a/mlpyqtgraph/axes.py b/mlpyqtgraph/axes.py index 7246729..2b64ab6 100644 --- a/mlpyqtgraph/axes.py +++ b/mlpyqtgraph/axes.py @@ -308,6 +308,10 @@ class Axis3D(gl.GLGraphicsItem.GLGraphicsItem): 'color': (0, 0, 0, 1), 'antialias': True, 'width': 1, + } + + grid_line_options = { + **default_line_options, 'glOptions': glOption_lines, } @@ -339,7 +343,7 @@ def set_projection_method(self, *coords, method='orthographic'): distance = 0.75*object_size/math.tan(0.5*field_of_view/180.0*math.pi) self.view().setCameraParams(fov=field_of_view, distance=distance) - def add(self, *args, **kwargs): + def surf(self, *args, **kwargs): """ Adds a 3D surface plot item to the view widget """ kwargs = dict(self.default_surface_options, **kwargs) surface = gl.GLSurfacePlotItem(*args, **kwargs) @@ -351,7 +355,7 @@ def add(self, *args, **kwargs): def calculate_ax_coord_lims(self, x, y, z): """ Calculates the axis coordinates limits """ - coords = dict(coord_generator(num_ticks=6, x=x, y=y, z=z)) + coords = dict(coord_generator(x=x, y=y, z=z)) limits = dict(limit_generator(limit_ratio=0.05, **coords)) return coords, limits @@ -359,24 +363,40 @@ def update_grid_axes(self, *args, **kwargs): """ Plots the grid axes """ coords, limits = self.calculate_ax_coord_lims(*args) self.grid_axes.setData(coords=coords, limits=limits) - projection_method = kwargs.get('projection', 'perspective') - self.view().setCameraPosition(**self.grid_axes.best_camera(method=projection_method)) + projection = kwargs.get('projection', 'perspective') + print(projection) + self.view().setCameraPosition(**self.grid_axes.best_camera(method=projection)) def add_grid_lines(self, *args): """ Plots all grid lines """ x, y, z = args[:3] rows, columns = z.shape for row in range(rows): - self.add_single_grid_line(x[row]*np.ones(columns), y, z[row]) + self.add_line( + x[row]*np.ones(columns), y, z[row], + **self.grid_line_options + ) for col in range(columns): - self.add_single_grid_line(x, y[col]*np.ones(rows), z[:, col]) + self.add_line( + x, y[col]*np.ones(rows), z[:, col], + **self.grid_line_options + ) - def add_single_grid_line(self, x, y, z): + def add_line(self, *args, **kwargs): """ Plots a single grid line for given coordinates """ - points = np.column_stack((x, y, z)) - line = gl.GLLinePlotItem(pos=points, **self.default_line_options) + points = np.column_stack(args) + line = gl.GLLinePlotItem(pos=points, **kwargs) self.view().addItem(line) + def line(self, *args, **kwargs): + """ Plots a single grid line for given coordinates """ + kwargs = dict(self.default_line_options, **kwargs) + lines_kwargs = dict(kwargs) + lines_kwargs.pop('projection') + self.add_line(*args, **lines_kwargs) + self.set_projection_method(*args, method=kwargs['projection']) + self.update_grid_axes(*args, **kwargs) + def delete(self): """ Closes the axis """ diff --git a/mlpyqtgraph/ml_functions.py b/mlpyqtgraph/ml_functions.py index 37f8f15..22744e2 100644 --- a/mlpyqtgraph/ml_functions.py +++ b/mlpyqtgraph/ml_functions.py @@ -51,4 +51,10 @@ def surf(*args, **kwargs): """ Plots a 3D surface """ gcf().change_layout('Qt') gcf().create_axis(axis_type='3D') - gca().add(*args, **kwargs) + gca().surf(*args, **kwargs) + +def plot3(*args, **kwargs): + """ Plots a 3D line """ + gcf().change_layout('Qt') + gcf().create_axis(axis_type='3D') + gca().line(*args, **kwargs) diff --git a/mlpyqtgraph/workers.py b/mlpyqtgraph/workers.py index 9eb33ec..7d3f64e 100644 --- a/mlpyqtgraph/workers.py +++ b/mlpyqtgraph/workers.py @@ -11,7 +11,8 @@ class AxisWorker(containers.WorkerItem): factory = containers.WorkerItem.get_factory() row = factory.attribute() column = factory.attribute() - add = factory.method() + surf = factory.method() + line = factory.method() add_legend = factory.method() grid = factory.attribute() xlim = factory.attribute()