Examples failed #24

Closed
blattnem opened this issue Jun 23, 2022 · 13 comments

@blattnem

blattnem commented Jun 23, 2022

I installed all requirements in an env. Running the https://github.com/google/lightweight_mmm/blob/main/examples/simple_end_to_end_demo.ipynb example fails when executing mmm.fit(...):

177 mcmc.run(
178 rng_key=jax.random.PRNGKey(seed),
179 media_data=jnp.array(media),
180 extra_features=extra_features,
181 target_data=jnp.array(target),
182 cost_prior=jnp.array(total_costs),
183 degrees_seasonality=degrees_seasonality,
184 frequency=seasonality_frequency,
185 transform_function=self._model_transform_function,
186 weekday_seasonality=weekday_seasonality)
188 if media_names is not None:
189 self.media_names = media_names

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
162 return tree_map(lambda *args: jnp.stack(args), *ys)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of init_params must be provided with" " potential_fn."
712 )

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

[... skipping hidden 8 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values params.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: OrderedDict containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

[... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/lightweight_mmm/models.py:187, in media_mix_model(media_data, target_data, cost_prior, degrees_seasonality, frequency, transform_function, transform_kwargs, weekday_seasonality, extra_features)
182 with numpyro.plate(name="beta_trend_plate", size=n_geos):
183 beta_trend = numpyro.sample(
184 name="beta_trend",
185 fn=dist.Normal(loc=0., scale=1.))
--> 187 expo_trend = numpyro.sample(
188 name="expo_trend",
189 fn=dist.Beta(concentration1=1., concentration0=1.))
191 with numpyro.plate(
192 name="channel_media_plate",
193 size=n_channels,
194 dim=-2 if media_data.ndim == 3 else -1):
195 beta_media = numpyro.sample(
196 name="channel_beta_media" if media_data.ndim == 3 else "beta_media",
197 fn=dist.HalfNormal(scale=cost_prior))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
204 initial_msg = {
205 "type": "sample",
206 "name": name,
(...)
215 "infer": {} if infer is None else infer,
216 }
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:789, in substitute.process_message(self, msg)
787 value = self.data.get(msg["name"])
788 else:
--> 789 value = self.substitute_fn(msg)
791 if value is not None:
792 msg["value"] = value

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:216, in _unconstrain_reparam(params, site)
213 return p
214 value = t(p)
--> 216 log_det = t.log_abs_det_jacobian(p, value)
217 log_det = sum_rightmost(
218 log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape)
219 )
220 if site["scale"] is not None:

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/distributions/transforms.py:816, in SigmoidTransform.log_abs_det_jacobian(self, x, y, intermediates)
815 def log_abs_det_jacobian(self, x, y, intermediates=None):
--> 816 return -softplus(x) - softplus(-x)

[... skipping hidden 20 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/nn/functions.py:66, in softplus(x)
54 @jax.jit
55 def softplus(x: Array) -> Array:
56 r"""Softplus activation function.
57
58 Computes the element-wise function
(...)
64 x : input array
65 """
---> 66 return jnp.logaddexp(x, 0)

[... skipping hidden 5 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:361, in _logaddexp_jvp(primals, tangents)
359 x1, x2 = primals
360 t1, t2 = tangents
--> 361 x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
362 primal_out = logaddexp(x1, x2)
363 tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
364 lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:327, in _promote_args_inexact(fun_name, *args)
325 _check_arraylike(fun_name, *args)
326 _check_no_float0s(fun_name, *args)
--> 327 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:262, in _promote_dtypes_inexact(*args)
258 def _promote_dtypes_inexact(*args):
259 """Convenience function to apply Numpy argument dtype promotion.
260
261 Promotes arguments to an inexact type."""
--> 262 to_dtype, weak_type = dtypes._lattice_result_type(*args)
263 to_dtype = dtypes.canonicalize_dtype(to_dtype)
264 to_dtype_inexact = _to_inexact_dtype(to_dtype)

[... skipping hidden 2 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/dtypes.py:311, in <genexpr>(.0)
309 N = set(nodes)
310 UB = _lattice_upper_bounds
--> 311 CUB = set.intersection(*(UB[n] for n in N))
312 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
313 if len(LUB) == 1:

KeyError: dtype([('float0', 'V')])


@pabloduque0
Collaborator

pabloduque0 commented Jun 23, 2022

hey @blattnem ! Thanks for reporting this one.

I might need a bit more info on your environment.

  1. What OS and version are you running?
  2. What version of lightweight_mmm are you using?
  3. Are you running the example as is? or with your own data?

Thanks!

pabloduque0 self-assigned this Jun 23, 2022
@blattnem
Author

  1. Mac OS Big Sur 11.5.2
  2. lightweight_mmm 0.1.3
  3. Run example as is

@pabloduque0
Collaborator

JAX support on Mac systems is not very extensive; to my understanding, it depends on the underlying hardware. It might work well on some ARM-based CPUs, but I am not sure about silicon-based ones. Also, this applies to CPUs; it might be trickier if you are trying to use a GPU.

Any chance you have used JAX on that system for something else? Has that worked fine?

I think you might need to install JAX and get it working on its own first, and then move on to lightweight_mmm.
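A minimal sketch of such a standalone JAX check (not from this thread), assuming only that `jax` is installed. It differentiates the same expression that `SigmoidTransform.log_abs_det_jacobian` computes in the traceback, so it goes through `jax.nn.softplus` and `logaddexp`'s JVP, where the reported `KeyError` surfaced. It may or may not reproduce the bug on the affected versions; on a working install it should run cleanly. The function names are hypothetical.

```python
import jax


def sigmoid_log_det(x):
    # Same formula as SigmoidTransform.log_abs_det_jacobian in the traceback above.
    return -jax.nn.softplus(x) - jax.nn.softplus(-x)


def jax_smoke_test():
    """Take a value and gradient through softplus, independent of lightweight_mmm."""
    value, grad = jax.value_and_grad(sigmoid_log_det)(0.0)
    return float(value), float(grad)


if __name__ == "__main__":
    print("jax", jax.__version__)
    value, grad = jax_smoke_test()
    # Expected: value = -2*log(2) (about -1.386) and grad = 0, up to float rounding.
    print(value, grad)
```

If this raises the same `KeyError: dtype([('float0', 'V')])`, the problem lies in the JAX install rather than in lightweight_mmm.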

My 2 cents: Linux is generally a much less painful setup, so if you have access to a Linux machine I would recommend trying it there. Otherwise, you can always use Colaboratory, which is free and should give you everything you need.

@blattnem
Author

blattnem commented Jun 24, 2022

Hi Pablo,
Sorry, same issue on Linux (Manjaro Linux).
I had to install Python 3.8 to install everything.
I get the same error as documented yesterday.
Best,
Marcel

@blattnem
Author

I installed lightweight_mmm with pip.

@pabloduque0
Copy link
Collaborator

Let me try to reproduce and get back to you.

When you said you had to install Python 3.8, what version were you using before?

I have a couple of setups with 3.7 and 3.9 where things seem to run fine, but we definitely want to support 3.8 as well, so I will look into it.

@blattnem
Copy link
Author

I have Python 3.10 on my Linux box. I installed 3.8 because of the dependency on tensorflow 2.7.2.

@blattnem
Copy link
Author

These are my libraries when running your example notebooks:


jax 0.3.13
jaxlib 0.3.10
lightweight_mmm 0.1.3
numpyro 0.9.2
session_info 1.0.0

Modules imported as dependencies:


IPython 8.4.0
jupyter_client 7.3.4
jupyter_core 4.10.0
notebook 6.4.12

Python 3.8.13 (default, Jun 24 2022, 11:47:12) [GCC 11.2.0]
Linux-5.10.109-1-MANJARO-x86_64-with-glibc2.34

Session information updated at 2022-06-28 09:07

I got exactly the same error. I am not sure it is related to the Python 3.8 version.

@blattnem
Author

blattnem commented Jun 28, 2022

...and the same with Python 3.9. I guess it's a JAX issue. What JAX version are you using?

@pabloduque0
Copy link
Collaborator

Okay, thank you for that. Let me dig deeper into this one and get back to you.

@usul83

usul83 commented Jun 29, 2022

It looks like this is fixed by jax 0.3.14 and jaxlib 0.3.14 (I was having the same issue yesterday with the versions above).

@pabloduque0
Copy link
Collaborator

@usul83 good to know that upgrade worked! Thanks for pointing it out!

I'm running our tests with these versions again to confirm that it is solved across the board.

@blattnem can you confirm it also solves your case?

@blattnem
Copy link
Author

blattnem commented Jul 1, 2022

Hi there. Great news! JAX 0.3.14 does the job! Thanks.
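A minimal sketch (not part of lightweight_mmm) of a pre-flight check for the versions reported as fixed in this thread, jax 0.3.14 and jaxlib 0.3.14. It reads installed versions with the standard-library `importlib.metadata`; the helper names and the simple version parser are hypothetical and only look at the first three numeric components.

```python
from importlib import metadata

# Versions reported in this thread to resolve the float0 KeyError.
MIN_FIXED = (0, 3, 14)


def parse_version(version: str) -> tuple:
    """Turn '0.3.14', '0.3.14.dev1' or '0.4.1rc0' into a tuple of leading integers."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def is_fixed_version(version: str) -> bool:
    return parse_version(version) >= MIN_FIXED


def check_installed(packages=("jax", "jaxlib")) -> dict:
    """Map each package to True/False (meets MIN_FIXED or not), or None if not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = is_fixed_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print(check_installed())
```

With the versions listed earlier in the thread (jax 0.3.13, jaxlib 0.3.10), both entries would come back `False`; a CPU-only environment can typically be upgraded with `pip install --upgrade "jax>=0.3.14" "jaxlib>=0.3.14"`, while GPU builds of jaxlib need the wheel matching your CUDA setup.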
