
Basically, JAX is a library that provides a NumPy-like API and is designed primarily for transforming numerical functions written in terms of array operations. Some people even regard JAX as a NumPy v2: it not only speeds up NumPy-style code but also adds automatic differentiation (grad), so that a machine learning framework can be built by relying on JAX alone.

Next, let's look at why the API that JAX provides is so similar to NumPy's. For now, you can think of JAX as NumPy with automatic differentiation, running on accelerators.

import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)

If you are familiar with NumPy or have written code with it, the code above should look familiar, and that is the beauty of JAX: the transition from NumPy is seamless, and you don't have to learn a new API. You can take code you used to write with NumPy, swap np for jnp, and run the program, although there are some differences, which will be discussed later. The x above is a DeviceArray, which is how JAX represents arrays.
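As a quick sanity check, here is a small sketch (the exact printed class name varies across JAX versions):

import numpy as np

print(type(x))        # JAX's DeviceArray type (exact class path depends on the JAX version)
print(np.asarray(x))  # it can still be converted back into a plain NumPy array when needed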

We will now compute the dot product of two vectors, calling block_until_ready, and run the code on the GPU device without changing the code. We use %timeit to check performance.

Technical detail: when a JAX function is called, the corresponding operation is dispatched to the accelerator and computed asynchronously, so the array returned by the call is not necessarily "filled in" when the function returns. Because the computation is asynchronous, Python execution is not blocked unless you actually need the result right away. Therefore, unless we call block_until_ready, we would only be timing the dispatch, not the actual computation. See the JAX documentation on asynchronous dispatch.

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 6.37 ms per loop
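To see the asynchronous dispatch described above in action, one can compare the two timings in a rough sketch like this (the exact numbers depend on the hardware):

# Without block_until_ready we mostly measure how long it takes to dispatch the work,
# because the result comes back as a "future" before the accelerator has finished.
%timeit jnp.dot(long_vector, long_vector)

# With block_until_ready we also wait for the computation itself to complete.
%timeit jnp.dot(long_vector, long_vector).block_until_ready()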

The first JAX transformation: grad

One of JAX's basic features is that it lets you transform functions. One of the most common transformations is jax.grad, which takes a numerical function written in Python and returns a new Python function that computes the gradient of the original. To start, define a function sum_of_squares that takes an array and returns the sum of the squares of its elements.

def sum_of_squares(x):
  return jnp.sum(x**2)

Applying jax.grad to sum_of_squares returns a different function: the gradient of sum_of_squares with respect to its first argument, x.

Then we pass an array to this gradient function, and it returns the derivative with respect to each element of the array.

sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))
30.0
[2. 4. 6. 8.]

If a function f(x) is passed to jax.grad, what comes back is equivalent to the function ∇f that computes the gradient of f:


$$(\nabla f)(x_i) = \frac{\partial f}{\partial x_i}(x_i)$$

Similarly, jax.grad(f) is a function that computes the gradient, so jax.grad(f)(x) is the gradient of f at x. (Like ∇, jax.grad only works on functions with scalar output; anything else raises an error.)
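As a quick illustration of that restriction, here is a sketch with a hypothetical identity function:

def identity(x):
    return x  # output is an array, not a scalar

try:
    jax.grad(identity)(jnp.arange(3.0))
except TypeError as e:
    print("jax.grad needs a scalar-output function:", e)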

This makes the JAX API very different from other deep learning frameworks that support automatic differentiation, such as TensorFlow and PyTorch, where you compute gradients from the loss tensor itself (for example, by calling loss.backward()). The JAX API works directly with functions and stays closer to the underlying mathematics. Once you get used to doing things this way, it feels natural: your loss function in code really is a function of the parameters and the data, and you take its gradient just as you would in mathematics.
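For example, a single gradient-descent step in this functional style might look like the following sketch (loss, params, and data are hypothetical names, not part of the examples above):

def loss(params, data):
    return jnp.sum((params - data) ** 2)  # a toy quadratic loss

grad_fn = jax.grad(loss)  # differentiates with respect to the first argument, params

params = jnp.zeros(3)
data = jnp.array([1.0, 2.0, 3.0])
params = params - 0.1 * grad_fn(params, data)  # one plain gradient-descent update
print(params)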

This way of doing things also makes it easy to control details such as which variables to differentiate with respect to. By default, jax.grad finds the gradient with respect to the first argument. In the following example, sum_squared_error_dx computes the gradient of sum_squared_error with respect to x.

def sum_squared_error(x, y):
  return jnp.sum((x-y)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

print(sum_squared_error_dx(x, y))

[-0.20000005 -0.19999981 -0.19999981 -0.19999981]

If you need the gradient with respect to a different argument (or several arguments), you can set argnums accordingly.
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # find gradients with respect to both x and y
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

Does this mean that, when doing machine learning, we need to write functions with huge argument lists, one for every array of model parameters? No: JAX has a mechanism for bundling arrays together into data structures called "pytrees", and jax.grad works on those directly.
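As a small illustration (a sketch with hypothetical names; pytrees are covered in more detail later), jax.grad can differentiate with respect to a dictionary of parameters directly:

# Hypothetical model parameters bundled in a dict (a simple pytree).
params = {"w": jnp.ones(3), "b": jnp.array(0.0)}

def model_loss(params, inputs, targets):
    preds = jnp.dot(inputs, params["w"]) + params["b"]
    return jnp.sum((preds - targets) ** 2)

# The gradient has the same dict structure as params.
grads = jax.grad(model_loss)(params, jnp.ones((4, 3)), jnp.zeros(4))
print(grads["w"], grads["b"])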

Value and grad

Often we want both the value of the loss and its gradient, for example to log the loss during training; jax.value_and_grad returns both at once.

jax.value_and_grad(sum_squared_error)(x, y)
(DeviceArray(0.03999995, dtype=float32),
 DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))
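A minimal usage sketch of unpacking the returned pair:

loss_value, grads = jax.value_and_grad(sum_squared_error)(x, y)
print(loss_value)  # same value as sum_squared_error(x, y)
print(grads)       # same gradient as jax.grad(sum_squared_error)(x, y)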

Auxiliary data

Besides the loss value itself, we often want to report some intermediate results obtained while computing the loss function. But if we try to do this with plain jax.grad, we run into trouble.

def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

jax.grad(squared_error_with_aux)(x, y)

The code above raises an error; to make it work, you need to pass an extra argument to jax.grad.

jax.grad(squared_error_with_aux, has_aux=True)(x, y)

This is because jax.grad is only defined for scalar-output functions, and squared_error_with_aux returns a tuple. The second element of that tuple is auxiliary data, which is where has_aux comes in: it tells jax.grad to differentiate only the first element and pass the auxiliary data through unchanged.
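A small usage sketch of the call above: with has_aux=True, the returned function yields a (gradient, auxiliary data) pair.

grads, aux = jax.grad(squared_error_with_aux, has_aux=True)(x, y)
print(grads)  # gradient of sum_squared_error with respect to x
print(aux)    # the auxiliary value x - y, returned unchanged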

JAX is different from NumPy

As you can see from the examples above, jax.numpy follows essentially the same API design as NumPy. However, the two are not identical, so let's go over the differences. The most important one is JAX's more functional programming style, which is the main reason jax.numpy and NumPy sometimes behave differently. A full introduction to functional programming (FP) is outside the scope of this guide, but if you are already familiar with FP, then JAX, which is designed around it, will feel quite natural.

import numpy as np

x = np.array([1, 2, 3])

def in_place_modify(x):
  x[0] = 123
  return None

in_place_modify(x)
x

If you are familiar with functional programming, you will see the problem as soon as you look at the output array([123, 2, 3]): in_place_modify has a side effect, silently updating the value of x that was passed in. In functional programming, data is supposed to be immutable: instead of mutating the original, you make a modified copy.

in_place_modify(jnp.array(x))

The call above raises an error in JAX, because JAX arrays are immutable and do not support in-place item assignment by indexing. Helpfully, the error message points to the side-effect-free jax.ops.index_* functions: instead of modifying the original array in place, they create a new array with the requested change applied.

def jax_in_place_modify(x):
  return jax.ops.index_update(x, 0, 123)

y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
DeviceArray([123,   2,   3], dtype=int32)

And then we look at y again and we see that it hasn’t changed.

y #DeviceArray([1, 2, 3], dtype=int32)

Side-effect-free code is sometimes called functionally pure, or simply pure: it computes its result without doing anything else, such as updating program state or performing I/O.

Isn't the pure version less efficient? Strictly speaking, yes: instead of modifying the existing array, we create a new one and apply the change to it. However, JAX computations are usually compiled before they run using another program transformation, jax.jit. If the original array is not used again after the "in-place" jax.ops.index_update(), the compiler can recognise that it is safe to compile the update into an actual in-place modification, resulting in efficient code.
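A hedged sketch of that point, using the same jax.ops.index_update API as above (newer JAX versions expose the equivalent x.at[0].set(123) instead):

@jax.jit
def jitted_modify(x):
    # A pure "copy and modify" update; under jit the compiler may perform it in place
    # when the original buffer is no longer needed.
    return jax.ops.index_update(x, 0, 123)

print(jitted_modify(jnp.array([1, 2, 3])))  # DeviceArray([123, 2, 3], dtype=int32)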

Of course, it is possible to mix Python code that has side effects with functionally pure JAX code; writing a program that is purely functional throughout is hard, or nearly impossible. As you become more familiar with JAX you will get better at knowing when each style is appropriate, and we will return to this later. For now, just remember to avoid side effects inside JAX code.