Skip to content

Commit

Permalink
refactor(laplace): clone matrix in compute
Browse files Browse the repository at this point in the history
  • Loading branch information
mkroening committed Dec 17, 2024
1 parent d2b7fab commit a2f37fc
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions examples/demo/src/laplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! This module performs the Jacobi method for solving Laplace's differential equation.
use std::time::Instant;
use std::vec;
use std::{mem, vec};

use rayon::prelude::*;

Expand All @@ -23,31 +23,27 @@ pub fn laplace() {
assert!(residual < 0.001);
}

fn matrix_setup(size_x: usize, size_y: usize) -> vec::Vec<vec::Vec<f64>> {
let mut matrix = vec![vec![0.0; size_x * size_y]; 2];
fn matrix_setup(size_x: usize, size_y: usize) -> vec::Vec<f64> {
let mut matrix = vec![0.0; size_x * size_y];

// top row
for x in 0..size_x {
matrix[0][x] = 1.0;
matrix[1][x] = 1.0;
for f in matrix.iter_mut().take(size_x) {
*f = 1.0;
}

// bottom row
for x in 0..size_x {
matrix[0][(size_y - 1) * size_x + x] = 1.0;
matrix[1][(size_y - 1) * size_x + x] = 1.0;
matrix[(size_y - 1) * size_x + x] = 1.0;
}

// left row
for y in 0..size_y {
matrix[0][y * size_x] = 1.0;
matrix[1][y * size_x] = 1.0;
matrix[y * size_x] = 1.0;
}

// right row
for y in 0..size_y {
matrix[0][y * size_x + size_x - 1] = 1.0;
matrix[1][y * size_x + size_x - 1] = 1.0;
matrix[y * size_x + size_x - 1] = 1.0;
}

matrix
Expand Down Expand Up @@ -93,20 +89,17 @@ fn iteration(cur: &[f64], next: &mut [f64], size_x: usize, size_y: usize) {
});
}

pub fn compute(mut matrix: vec::Vec<vec::Vec<f64>>, size_x: usize, size_y: usize) -> (usize, f64) {
pub fn compute(matrix: vec::Vec<f64>, size_x: usize, size_y: usize) -> (usize, f64) {
let mut current = matrix;
let mut next = current.clone();
let mut counter = 0;

while counter < 1000 {
{
// allow a borrow and a reference to the same vector
let (current, next) = matrix.split_at_mut(1);

iteration(&current[0], &mut next[0], size_x, size_y);
}
matrix.swap(0, 1);
iteration(&current, &mut next, size_x, size_y);
mem::swap(&mut current, &mut next);

counter += 1;
}

(counter, get_residual(&matrix[0], size_x, size_y))
(counter, get_residual(&current, size_x, size_y))
}

0 comments on commit a2f37fc

Please sign in to comment.