RuntimeError: std::bad_cast when running examples #42

Open
ytwang-acs opened this issue Nov 12, 2024 · 1 comment

@ytwang-acs

Hi JAX-FEM community,
I just installed JAX-FEM on my Mac following the instructions.
When I tried to run the examples, all of them failed with the same RuntimeError.
For example, running `python -m demos.wave.example` gives:

```text
[11-12 13:50:24][DEBUG] jax_fem: Computing shape function values, gradients, etc.
[11-12 13:50:24][DEBUG] jax_fem: ele_type = TRI3, quad_points.shape = (num_quads, dim) = (3, 2)
[11-12 13:50:24][DEBUG] jax_fem: face_quad_points.shape = (num_faces, num_face_quads, dim) = (3, 2, 2)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/tong/Desktop/github_tryouts/jax-fem/demos/wave/example.py", line 184, in <module>
    main_fns()
  File "/Users/tong/Desktop/github_tryouts/jax-fem/demos/wave/example.py", line 166, in main_fns
    problem = wave(mesh, vec=1, dim=2, ele_type = ele_type, gauss_order=2, dirichlet_bc_info = dirichlet_bc_info)
  File "<string>", line 11, in __init__
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/problem.py", line 37, in __post_init__
    self.fes = [FiniteElement(mesh=self.mesh[I], 
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/problem.py", line 37, in <listcomp>
    self.fes = [FiniteElement(mesh=self.mesh[I], 
  File "<string>", line 10, in __init__
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/fe.py", line 79, in __post_init__
    self.node_inds_list, self.vec_inds_list, self.vals_list = self.Dirichlet_boundary_conditions(self.dirichlet_bc_info)
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/fe.py", line 221, in Dirichlet_boundary_conditions
    node_inds = onp.argwhere(jax.vmap(location_fn)(self.mesh.points, np.arange(self.num_total_nodes))).reshape(-1)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3582, in arange
    return lax.iota(dtype, start)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1321, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1331, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/array.py", line 1146, in _array_global_result_handler
    return xc.array_result_handler(
RuntimeError: std::bad_cast
```
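The last JAX-FEM frame in the traceback is the Dirichlet-node search in `jax_fem/fe.py` (line 221). The exception is raised inside `np.arange`, which is `jax.numpy.arange` according to the `lax_numpy.py` frame, before the vmapped `location_fn` is ever called. The following minimal sketch mimics that call without importing JAX-FEM, which may help separate a broken JAX install from a JAX-FEM bug. If this sketch raises the same `std::bad_cast`, the problem most likely lies in the JAX installation rather than in JAX-FEM. The boundary function `left`, the helper `dirichlet_node_inds`, and the four-point "mesh" are hypothetical stand-ins and are not part of the JAX-FEM API.

```python
# Minimal sketch (assumes only that jax and numpy are installed; does not
# import jax_fem). It mimics the Dirichlet-node search at jax_fem/fe.py
# line 221 from the traceback. `left`, `dirichlet_node_inds` and the
# four-point "mesh" are hypothetical stand-ins, not JAX-FEM API.
import numpy as onp
import jax
import jax.numpy as np


def left(point, ind):
    # Hypothetical location function: True for nodes on the edge x = 0.
    # `ind` (the node index) is accepted but unused, matching the
    # two-argument call in the traceback.
    return np.isclose(point[0], 0.0, atol=1e-5)


def dirichlet_node_inds(points, location_fn):
    """Return indices of nodes for which location_fn(point, ind) is True."""
    # jax.numpy.arange is the call that raised std::bad_cast in the report;
    # on a healthy install it simply builds [0, 1, ..., n - 1].
    node_ids = np.arange(points.shape[0])
    # Evaluate the location function once per node, vectorized with vmap.
    mask = jax.vmap(location_fn)(points, node_ids)
    # Convert the boolean mask to a flat array of node indices (host NumPy).
    return onp.argwhere(mask).reshape(-1)


if __name__ == "__main__":
    points = onp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    print(jax.__version__)
    print(dirichlet_node_inds(points, left))
```

On a working install, the second printed line should be `[0 2]`, which are the indices of the two nodes lying on x = 0.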
@tianjuxue
Collaborator

I am running JAX-FEM with JAX version 0.4.30. Could you try installing that specific version of JAX and check whether the error persists?
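Before and after reinstalling, it may help to confirm which versions are actually active in the environment. The sketch below relies only on the standard library's `importlib.metadata`, so it still works when `import jax` itself fails. The helper names `jax_versions` and `report` are hypothetical, and 0.4.30 is simply the version mentioned in the reply above. Checking `jaxlib` in addition to `jax` is an assumption on our part. As far as we can tell, the failing `xc.array_result_handler` call is compiled code shipped in `jaxlib`, which is a separate package from `jax`.

```python
# Sketch: report the installed jax / jaxlib versions against the 0.4.30
# version mentioned in the thread. Uses only the standard library, so it
# runs even if importing jax raises. Helper names are hypothetical.
from importlib.metadata import PackageNotFoundError, version

PINNED = "0.4.30"


def jax_versions(packages=("jax", "jaxlib")):
    """Return {package: installed version string, or None if absent}."""
    found = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found


def report(expected=PINNED):
    """Print each package's version and whether it matches `expected`."""
    found = jax_versions()
    for pkg, ver in found.items():
        if ver is None:
            status = "not installed"
        elif ver == expected:
            status = "matches " + expected
        else:
            status = "differs from " + expected
        print(f"{pkg}: {ver} ({status})")
    return found


if __name__ == "__main__":
    report()
```

Run it with the same interpreter that runs the demos (here, the `jax-fem-env` conda environment). Otherwise it may report a different installation.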
