
Jax Apply Function Only On Slice Of Array Under Jit

I am using JAX, and I want to perform an operation like

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

This cannot be performed under jit. Is there

Solution 1:

The answer by @rvinas, which combines static_argnums with dynamic_update_slice, works well if your index is static, but you can also accomplish this with a dynamic (traced) index using jnp.where. For example:

import jax
import jax.numpy as jnp

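def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
    # True for positions before index; index may be a traced value.
    mask = jnp.arange(x.shape[0]) < index
    # Take other_fun(x) where mask is True and the original x elsewhere.
    return jnp.where(mask, other_fun(x), x)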

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
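Because index is traced in this version, a single compiled function serves every index value. One way to see this, sketched here with the definitions repeated so the block runs on its own, is to map fun over a batch of indices with jax.vmap; in_axes=(None, 0) shares x across the batch and maps only over index:

```python
import jax
import jax.numpy as jnp


def other_fun(x):
    # Elementwise placeholder for the user's function.
    return x + 1


@jax.jit
def fun(x, index):
    # True for positions 0 .. index-1; index may be a traced value.
    mask = jnp.arange(x.shape[0]) < index
    # other_fun runs on the whole array; where() keeps its result only
    # where mask is True and the original value elsewhere.
    return jnp.where(mask, other_fun(x), x)


x = jnp.arange(5)
indices = jnp.array([0, 2, 5])

# Batch over the index argument only; x is shared (in_axes=None).
batched = jax.vmap(fun, in_axes=(None, 0))(x, indices)
print(batched)
# [[0 1 2 3 4]
#  [1 2 2 3 4]
#  [1 2 3 4 5]]
```

Note that jnp.where evaluates other_fun on the entire array and then discards the unmasked part, so this matches the slice-based version only when other_fun is elementwise (each output element depends only on the corresponding input element). A function such as a cumulative sum or a normalization over the slice would give different results.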

Solution 2:

It seems there are two issues in your implementation. First, slicing with a traced index (x[:index]) produces an array whose shape depends on a runtime value, and dynamically shaped arrays are not allowed in jitted code. Second, unlike NumPy arrays, JAX arrays are immutable (i.e. the contents of an existing array cannot be changed in place).
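The immutability point can be seen directly, even outside jit: item assignment on a JAX array raises a TypeError, and the functional x.at[...].set(...) syntax returns an updated copy while leaving the original untouched. This minimal sketch uses a concrete slice bound of 3, so the dynamic-shape issue does not arise:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place assignment is rejected because JAX arrays are immutable.
assignment_failed = False
try:
    x[:3] = x[:3] + 1
except TypeError:
    assignment_failed = True
print(assignment_failed)  # True

# Functional update: returns a new array, x itself is unchanged.
y = x.at[:3].set(x[:3] + 1)
print(x)  # [0 1 2 3 4]
print(y)  # [1 2 3 3 4]
```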

You can overcome the two problems by combining static_argnums and jax.lax.dynamic_update_slice. Here is an example:

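from functools import partial
import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

# jax.partial has been removed from JAX; functools.partial does the same job.
@partial(jax.jit, static_argnums=(1,))  # index must be a hashable Python int
def fun(x, index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))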

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]

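Essentially, static_argnums=(1,) marks index as static, so JAX treats it as a Python constant and recompiles the function for each distinct index value; this gives x[:index] a concrete shape. jax.lax.dynamic_update_slice then returns a new array (rather than modifying x) equal to x with update written starting at position 0, i.e. over x[:len(update)].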
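A related case: if the length of the region is fixed but its starting position should vary without recompiling, jax.lax.dynamic_slice can extract the window and jax.lax.dynamic_update_slice can write it back, keeping the start offset traced and only the size static. A minimal sketch, where apply_window and its size parameter are hypothetical names introduced for illustration and other_fun is the same placeholder as above:

```python
from functools import partial

import jax
import jax.numpy as jnp


def other_fun(x):
    return x + 1


# apply_window is a hypothetical helper for illustration.
@partial(jax.jit, static_argnums=(2,))
def apply_window(x, start, size):
    # size must be static: it determines the shape of the extracted slice.
    window = jax.lax.dynamic_slice(x, (start,), (size,))
    # start stays traced; the result is written back at the same offset.
    return jax.lax.dynamic_update_slice(x, other_fun(window), (start,))


x = jnp.arange(5)
print(apply_window(x, 1, 2))  # [0 2 3 3 4]
print(apply_window(x, 2, 2))  # [0 1 3 4 4]
```

Both primitives clamp the start index so the window stays in bounds (for example, a start of 4 with size 2 on a length-5 array behaves like a start of 3), so an out-of-range offset shifts the window instead of raising an error.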
