Examples failed #24
Hey @blattnem! Thanks for reporting this one. I might need a bit more info on your environment.
Thanks!
JAX support on Mac systems is limited and, to my understanding, depends on the underlying hardware. It might work well on some ARM-based CPUs, but I am not sure about silicon-based ones. That is for CPUs; it might be trickier if you are trying to use a GPU. Any chance you have used JAX on that system for something else? Has that worked fine? I think you might need to install JAX and get it working separately and then go into [...]. My 2 cents: Linux is generally a much less painful setup, so if you have access to a Linux machine I would recommend trying it there. Otherwise you can always use Colaboratory, which is free and should give you all you need.
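One way to check whether JAX works on its own, before involving lightweight_mmm, is a small smoke test like the sketch below. The function name `jax_smoke_test` is hypothetical, and the check (device listing plus one gradient) is only an assumption about what "working" should mean here.

```python
def jax_smoke_test():
    """Return (ok, message) describing whether a basic JAX computation runs."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        return False, f"JAX is not importable: {err}"

    try:
        # List the devices JAX can see (CPU, GPU, ...).
        devices = [str(d) for d in jax.devices()]
        # Gradient of sum(x^2) is 2*x, so the first entry should be 2.0 for x = ones.
        grad_first = float(jax.grad(lambda x: jnp.sum(x ** 2))(jnp.ones(3))[0])
    except Exception as err:  # any backend or compilation failure
        return False, f"JAX imported but a basic computation failed: {err!r}"

    ok = abs(grad_first - 2.0) < 1e-6
    status = "passed" if ok else "failed"
    return ok, f"jax {jax.__version__} on {devices}: gradient check {status}"


if __name__ == "__main__":
    print(jax_smoke_test()[1])
```

If this fails, the problem is likely in the JAX installation itself rather than in lightweight_mmm.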
Hi Pablo,
I installed lightweight_mmm with pip.
Let me try to reproduce and get back to you. When you said you had to install Python 3.8, what version were you using before? I have a couple of setups with 3.7 and 3.9 where things seem to be running fine, but we definitely want to support 3.8 as well, so I will look into it.
I have Python 3.10 on my Linux box. I installed 3.8 because of the dependency on tensorflow 2.7.2.
These are my libs when running your example notebooks: jax 0.3.13
...and the same with Python 3.9. I guess it's a JAX issue. What JAX version are you using?
Okay, thank you for that. Let me dig deeper into this one and get back to you.
Looks like this is fixed by jax-0.3.14 & jaxlib-0.3.14 (I was having the same issue yesterday with the versions above).
Hi there. Great news! JAX 0.3.14 does the job! Thanks.
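For anyone landing here with the same error, a quick check of the installed versions against 0.3.14 (the version reported above as fixing it) might look like the sketch below. The helper names are hypothetical, and the version parsing is deliberately simple: it only reads the first three numeric components.

```python
import re
from importlib import metadata

# Version reported in this thread as fixing the error.
MIN_VERSION = (0, 3, 14)


def parse_version(text):
    """Turn '0.3.14' or '0.3.14+cuda11' into a tuple of up to three ints."""
    return tuple(int(n) for n in re.findall(r"\d+", text)[:3])


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def is_fixed(version_text, minimum=MIN_VERSION):
    """True if the given version is at least the minimum."""
    return parse_version(version_text) >= minimum


if __name__ == "__main__":
    for package in ("jax", "jaxlib"):
        found = installed_version(package)
        if found is None:
            print(f"{package}: not installed")
        else:
            verdict = "ok" if is_fixed(found) else "older than 0.3.14"
            print(f"{package} {found}: {verdict}")
```

If either package is older, upgrading both together (for example `pip install --upgrade jax jaxlib`) is the usual route; GPU builds of jaxlib may need the install command from the JAX installation guide instead.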
I installed all requirements in an env. Running the https://github.com/google/lightweight_mmm/blob/main/examples/simple_end_to_end_demo.ipynb example fails when executing mmm.fit(.....)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
162 return tree_map(lambda *args: jnp.stack(args), *ys)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of `init_params` must be provided with" " `potential_fn`."
712 )
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values ``params``.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/lightweight_mmm/models.py:187, in media_mix_model(media_data, target_data, cost_prior, degrees_seasonality, frequency, transform_function, transform_kwargs, weekday_seasonality, extra_features)
182 with numpyro.plate(name="beta_trend_plate", size=n_geos):
183 beta_trend = numpyro.sample(
184 name="beta_trend",
185 fn=dist.Normal(loc=0., scale=1.))
--> 187 expo_trend = numpyro.sample(
188 name="expo_trend",
189 fn=dist.Beta(concentration1=1., concentration0=1.))
191 with numpyro.plate(
192 name="channel_media_plate",
193 size=n_channels,
194 dim=-2 if media_data.ndim == 3 else -1):
195 beta_media = numpyro.sample(
196 name="channel_beta_media" if media_data.ndim == 3 else "beta_media",
197 fn=dist.HalfNormal(scale=cost_prior))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
204 initial_msg = {
205 "type": "sample",
206 "name": name,
(...)
215 "infer": {} if infer is None else infer,
216 }
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:789, in substitute.process_message(self, msg)
787 value = self.data.get(msg["name"])
788 else:
--> 789 value = self.substitute_fn(msg)
791 if value is not None:
792 msg["value"] = value
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:216, in _unconstrain_reparam(params, site)
213 return p
214 value = t(p)
--> 216 log_det = t.log_abs_det_jacobian(p, value)
217 log_det = sum_rightmost(
218 log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape)
219 )
220 if site["scale"] is not None:
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/distributions/transforms.py:816, in SigmoidTransform.log_abs_det_jacobian(self, x, y, intermediates)
815 def log_abs_det_jacobian(self, x, y, intermediates=None):
--> 816 return -softplus(x) - softplus(-x)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/nn/functions.py:66, in softplus(x)
54 @jax.jit
55 def softplus(x: Array) -> Array:
56 r"""Softplus activation function.
57
58 Computes the element-wise function
(...)
64 x : input array
65 """
---> 66 return jnp.logaddexp(x, 0)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:361, in _logaddexp_jvp(primals, tangents)
359 x1, x2 = primals
360 t1, t2 = tangents
--> 361 x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
362 primal_out = logaddexp(x1, x2)
363 tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
364 lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:327, in _promote_args_inexact(fun_name, *args)
325 _check_arraylike(fun_name, *args)
326 _check_no_float0s(fun_name, *args)
--> 327 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:262, in _promote_dtypes_inexact(*args)
258 def _promote_dtypes_inexact(*args):
259 """Convenience function to apply Numpy argument dtype promotion.
260
261 Promotes arguments to an inexact type."""
--> 262 to_dtype, weak_type = dtypes._lattice_result_type(*args)
263 to_dtype = dtypes.canonicalize_dtype(to_dtype)
264 to_dtype_inexact = _to_inexact_dtype(to_dtype)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/dtypes.py:311, in <genexpr>(.0)
309 N = set(nodes)
310 UB = _lattice_upper_bounds
--> 311 CUB = set.intersection(*(UB[n] for n in N))
312 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
313 if len(LUB) == 1:
KeyError: dtype([('float0', 'V')])
dtype([('float0', 'V')])
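The last frames show the failure happening while differentiating the sigmoid transform's log-Jacobian, `-softplus(x) - softplus(-x)`, when JAX tries to promote a `float0` tangent. The sketch below differentiates that same expression with plain JAX, without numpyro or lightweight_mmm. The function names are hypothetical, and it has not been verified here that this reproduces the KeyError on jax 0.3.13; it is only expected to exercise the same code path. On a working install the gradient should match the closed form `1 - 2 * sigmoid(x)`.

```python
import math

try:
    import jax
except ImportError:  # keep the sketch importable where JAX is missing
    jax = None


def expected_grad(x):
    """Closed-form derivative of -softplus(x) - softplus(-x)."""
    sigmoid = 1.0 / (1.0 + math.exp(-x))
    return 1.0 - 2.0 * sigmoid


def sigmoid_log_det_value_and_grad(x=0.5):
    """Value and gradient of the expression from the traceback, or None without JAX."""
    if jax is None:
        return None

    def log_abs_det(z):
        # Same expression as SigmoidTransform.log_abs_det_jacobian in the traceback.
        return -jax.nn.softplus(z) - jax.nn.softplus(-z)

    value, grad = jax.value_and_grad(log_abs_det)(x)
    return float(value), float(grad)


if __name__ == "__main__":
    result = sigmoid_log_det_value_and_grad(0.5)
    if result is None:
        print("JAX is not installed")
    else:
        print("value, grad:", result, "expected grad:", expected_grad(0.5))
```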