Jax Apply Function Only On Slice Of Array Under Jit
Solution 1:
The answer by @rvinas using dynamic_update_slice works well if your index is static, but you can also accomplish this with a dynamic index using jnp.where. For example:
import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
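One way to check that the index really is dynamic in the jnp.where approach is to count how often JAX traces the function. The following is a minimal sketch, relying on the fact that Python side effects inside a jitted function run only while it is being traced. The names trace_count and masked_apply are hypothetical and introduced only for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


def other_fun(x):
    return x + 1


@jax.jit
def masked_apply(x, index):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    mask = jnp.arange(x.shape[0]) < index
    # other_fun is evaluated on every element; jnp.where then selects per position
    return jnp.where(mask, other_fun(x), x)


x = jnp.arange(5)
results = [masked_apply(x, i) for i in (1, 3, 5)]
for r in results:
    print(r)
print("traces:", trace_count)
```

Because the index is an ordinary traced argument, the three calls should share one compiled function. The trade-off is that other_fun runs over the whole array, so this pattern fits best when other_fun is a cheap elementwise operation.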
Solution 2:
It seems there are two issues in your implementation. First, the slices produce dynamically shaped arrays, which are not allowed in jitted code. Second, unlike NumPy arrays, JAX arrays are immutable (i.e. the contents of the array cannot be changed).
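The immutability point can be illustrated outside of jit with a short sketch. It assumes only the standard x.at[...] functional-update syntax of JAX arrays; the variable names are made up for this example.

```python
import jax.numpy as jnp

x = jnp.arange(5)

try:
    x[:3] = x[:3] + 1  # in-place assignment, as one would write in NumPy
    assigned_in_place = True
except TypeError:
    assigned_in_place = False  # JAX arrays reject item assignment

# Functional update: returns a new array and leaves x unchanged
y = x.at[:3].add(1)
print(assigned_in_place)
print(x)
print(y)
```

The in-place assignment fails, while the .at[...] form returns a new array and leaves x unchanged.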
You can overcome the two problems by combining static_argnums and jax.lax.dynamic_update_slice. Here is an example:
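import jax
import jax.numpy as jnp
from functools import partial

def other_fun(x):
    return x + 1

@partial(jax.jit, static_argnums=(1,))  # index is static: one compilation per distinct value
def fun(x, index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]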
Essentially, the example above uses static_argnums to indicate that the function should be recompiled for different index values, and jax.lax.dynamic_update_slice creates a copy of x with updated values at :len(update).
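The recompilation behaviour described above can be observed by counting traces. This is a rough sketch rather than part of the original answer: trace_count and static_slice_apply are hypothetical names, and functools.partial is used to apply jax.jit with static_argnums.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


def other_fun(x):
    return x + 1


@partial(jax.jit, static_argnums=(1,))
def static_slice_apply(x, index):
    global trace_count
    trace_count += 1  # runs once per compilation, not on cached calls
    update = other_fun(x[:index])  # shape is known because index is static
    return jax.lax.dynamic_update_slice(x, update, (0,))


x = jnp.arange(5)
out_a = static_slice_apply(x, 3)  # first call: traced and compiled
out_b = static_slice_apply(x, 3)  # same static value: served from the cache
out_c = static_slice_apply(x, 2)  # new static value: traced again
print(out_a, out_c)
print("traces:", trace_count)
```

Each distinct index value costs one compilation, so this approach is most attractive when index takes only a handful of values over the life of the program.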