[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context on why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars, 0-dimensional JAX arrays are not hashable. This change updates callers to use `.tolist()` or `np.asarray(...)` wherever the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`; the sketch after this list illustrates the workarounds. In some instances, we can simply call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* After this change, Pandas and PIL mishandle JAX arrays when converting them to a dataframe or an image. For now, cast JAX arrays to NumPy arrays before passing them to those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.
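
A minimal sketch of the three workarounds above (the array values are illustrative, not part of the change itself):

import jax.numpy as jnp
import numpy as np

xs = jnp.arange(3)

# 1. Iteration now yields 0-dimensional JAX arrays, which are not hashable;
#    convert to Python ints or NumPy scalars before hashing.
keys = set(xs.tolist())      # {0, 1, 2}, via Python ints
keys = set(np.asarray(xs))   # equivalent, via NumPy scalars

# 2. Cast to a NumPy array before handing data to Pandas or PIL, e.g.
#    pd.DataFrame(np.asarray(some_jax_matrix)).

# 3. Compare JAX arrays with assert_array_equal rather than assert_equal.
np.testing.assert_array_equal(xs, np.arange(3))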

PiperOrigin-RevId: 406860111
Language Team authored and kentonl committed Nov 13, 2021
1 parent aa58066 · commit 171436f
Showing 8 changed files with 19 additions and 13 deletions.
@@ -79,7 +79,7 @@ def test_attention_layer(self):

     # Check input was not changed where it should not be
     all_indices = set(
-        itertools.product(jnp.arange(self.bsz), jnp.arange(self.seq_len)))
+        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
     start_indices = set(
         zip(self.mention_batch_positions, self.mention_start_positions))
     non_start_indices = all_indices.difference(start_indices)
5 changes: 3 additions & 2 deletions language/mentionmemory/modules/memory_attention_layer_test.py
@@ -178,10 +178,11 @@ def test_mention_memory_layer(self, separate_memory_values):

     # Check input was not changed where it should not be
     all_indices = set(
-        itertools.product(jnp.arange(self.bsz), jnp.arange(self.seq_len)))
+        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
     # Note that mention positions is the same across all of the devices
     start_indices = set(
-        zip(mention_batch_positions[0], mention_start_positions[0]))
+        zip(mention_batch_positions[0].tolist(),
+            mention_start_positions[0].tolist()))
     non_start_indices = all_indices.difference(start_indices)
     non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
     non_start_indices_1 = jnp.asarray(non_start_indices_1)
@@ -116,7 +116,7 @@ def test_attention_layer(self, retrieval_update_type,

     # Check input was not changed where it should not be
     all_indices = set(
-        itertools.product(jnp.arange(self.bsz), jnp.arange(self.seq_len)))
+        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
     start_indices = set(
         zip(self.mention_batch_positions, self.mention_start_positions))
     non_start_indices = all_indices.difference(start_indices)
7 changes: 3 additions & 4 deletions language/mentionmemory/tasks/mention_based_entity_qa_task.py
@@ -53,10 +53,9 @@ def get_predictions_sum(attention_weights: Array, memory_entity_ids: Array,
   n_mentions = attention_weights.shape[0]
   attention_weights_per_entity = jnp.zeros((n_mentions, entity_vocab_size),
                                            dtype=attention_weights.dtype)
-  attention_weights_per_entity = jax.ops.index_add(
-      attention_weights_per_entity,
-      (jnp.expand_dims(jnp.arange(n_mentions), 1), memory_entity_ids),
-      attention_weights)
+  attention_weights_per_entity = attention_weights_per_entity.at[
+      jnp.expand_dims(jnp.arange(n_mentions), 1),
+      memory_entity_ids].add(attention_weights)
   predictions = jnp.argmax(attention_weights_per_entity, axis=1)
   predictions = predictions * weights
   return predictions
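
This hunk migrates from the deprecated `jax.ops.index_add` to JAX's `.at[...].add(...)` indexed-update syntax. A minimal standalone sketch of the equivalence, with illustrative names and values:

import jax.numpy as jnp

counts = jnp.zeros(4)
idx = jnp.array([0, 1, 1, 3])
# Functional scatter-add: returns a new array, and repeated indices accumulate.
counts = counts.at[idx].add(1.0)  # -> [1., 2., 0., 1.]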
4 changes: 2 additions & 2 deletions language/mentionmemory/tasks/mention_memory_task_test.py
@@ -366,7 +366,7 @@ def test_same_entity_set_retrieval_loss(self,
       sample_weights = mention_target_weights[mention_target_batch_positions
                                               == batch_index]
       sample_ids = sample_ids[sample_weights > 0]
-      sample_ids = set([x for x in sample_ids if x != 0])
+      sample_ids = set([x for x in sample_ids.tolist() if x != 0])

       for m_index in range(n_mentions_per_local_batch):
         if mention_batch_positions[m_index] != batch_index:
@@ -377,7 +377,7 @@ def test_same_entity_set_retrieval_loss(self,
         n_correct_retrievals, n_incorrect_retrievals = 0, 0
         for r_index in range(n_retrievals):
           common_ids = set(
-              memory_text_entities[r_index]).intersection(sample_ids)
+              memory_text_entities[r_index].tolist()).intersection(sample_ids)
           num_commons[m_index, r_index] = len(common_ids)
           if len(common_ids) >= config.same_entity_set_target_threshold:
             n_correct_retrievals += 1
2 changes: 1 addition & 1 deletion language/mentionmemory/utils/jax_utils.py
@@ -130,7 +130,7 @@ def vmap_slice(array: Array, indices: Array) -> Array:
 @jax.vmap
 def vmap_index_add(array: Array, indices: Array, values: Array) -> Array:
   """Convenience function for index add that differs along first dimension."""
-  return jax.ops.index_add(array, indices, values)
+  return array.at[indices].add(values)


def cosine_similarity(a: Array, b: Array) -> Array:
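
A hypothetical usage sketch of the vmapped helper above, assuming 2-D inputs whose leading axis is the batch dimension:

import jax
import jax.numpy as jnp

@jax.vmap
def vmap_index_add(array, indices, values):
  return array.at[indices].add(values)

batch = jnp.zeros((2, 3))
idx = jnp.array([[0], [2]])
vals = jnp.array([[5.0], [7.0]])
# Each row performs its own scatter-add: row 0 at position 0, row 1 at position 2.
print(vmap_index_add(batch, idx, vals))
# [[5. 0. 0.]
#  [0. 0. 7.]]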
6 changes: 6 additions & 0 deletions language/nql/nql/__init__.py
@@ -625,11 +625,17 @@ def get_config(self):
   # Allow TF deserialization.
   @classmethod
   def from_config(cls, config):
+    """Recreate a NeuralQueryContext from a saved config."""
     if 'np_initval' in config:
       config['np_initval'] = {
           rel: NeuralQueryContext._dict_to_sparse(matrix_dict)
           for (rel, matrix_dict) in config['np_initval'].items()
       }
+    if 'symtab' in config:
+      symtab = dict()
+      for (k, v) in config['symtab']:
+        symtab[k] = symbol.create_from_dict(v)
+      config['symtab'] = symtab
     return cls(**config)

# Basic API
4 changes: 2 additions & 2 deletions language/search_agents/muzero/transformer_encoder.py
@@ -26,7 +26,7 @@
 from tensorflow.core.protobuf import trackable_object_graph_pb2  # pylint: disable=g-direct-tensorflow-import
 from official.modeling import activations
 from official.modeling import tf_utils
-from official.nlp import keras_nlp
+from official.nlp import modeling
 from official.nlp.bert import configs
 from official.nlp.modeling import layers
@@ -177,7 +177,7 @@ def __init__(self,
     word_embeddings = self._embedding_layer(word_ids)

     # Always uses dynamic slicing for simplicity.
-    self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
+    self._position_embedding_layer = modeling.layers.PositionEmbedding(
         initializer=initializer, max_length=max_sequence_length)
     position_embeddings = self._position_embedding_layer(word_embeddings)
     all_embeddings = [word_embeddings, position_embeddings]
