Debugging JAX Machine Learning and Deep Learning Models and Types of Python Strings
Amine M'Charrak, 01 April 2023
What is the difference between a regular string and an f-string?
In Python, a string is a sequence of characters enclosed within single quotes (‘ ‘) or double quotes (“ “). It can be used to store and manipulate text data. An f-string is a special type of string introduced in Python 3.6 that allows for formatted string literals. It is enclosed within single quotes or double quotes, preceded by the letter ‘f’ (or ‘F’), and contains expressions inside curly braces {} that are evaluated at runtime.
Thus, the main difference between a regular string and an f-string lies in how their contents are treated. A regular string literal is plain text: anything inside it, including curly braces, is stored as-is and nothing is evaluated. In an f-string, the expressions inside the curly braces are evaluated and substituted into the string at runtime, at the moment the line containing the f-string is executed.
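The timing can be made visible with a small sketch. The variable names below are made up for illustration; the point is only that the f-string captures the value of name at the line where it appears, while a regular string keeps its braces as text until we explicitly format it.

```python
name = "John"

# A regular string is just text: the braces are kept literally.
plain = "Hello, {name}!"

# An f-string is evaluated when this line runs, using the current value of name.
greeting = f"Hello, {name}!"

# Rebinding name afterwards does not change the string that was already built.
name = "Jane"

# A regular string can act as a template that is filled in later.
later = plain.format(name=name)

print(plain)     # Hello, {name}!
print(greeting)  # Hello, John!
print(later)     # Hello, Jane!
```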
Here is an example that illustrates the difference between a regular string and the ‘fancier’ f-string:
name = "John"
age = 30
# Regular string
str1 = "My name is " + name + " and I am " + str(age) + " years old."
print(str1)
# f-string
str2 = f"My name is {name} and I am {age} years old."
print(str2)
This gives the following output:
My name is John and I am 30 years old.
My name is John and I am 30 years old.
As we can see, both strings produce the same output, but the f-string is more concise and easier to read. In addition, the f-string allows for more complex expressions and formatting options, making it a useful tool for building dynamic strings.
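To give a flavour of those expressions and formatting options, here is a short sketch. The height_m variable is an assumption added for illustration; everything else uses standard f-string syntax.

```python
name = "John"
age = 30
height_m = 1.8349  # hypothetical extra value, only for this example

# Arbitrary expressions are allowed inside the braces.
next_year = f"{name} will be {age + 1} next year."

# A format spec after the colon controls presentation (two decimals here).
height = f"Height: {height_m:.2f} m"

# !r inserts repr() of the value, which is handy when debugging.
debug = f"name={name!r}"

# Width and alignment: right-align age in a field of 5 characters.
padded = f"[{age:>5}]"

print(next_year)  # John will be 31 next year.
print(height)     # Height: 1.83 m
print(debug)      # name='John'
print(padded)     # [   30]
```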
How to use strings in JAX to debug and inspect our Machine Learning and Deep Learning algorithms:
The cool thing is that it is possible to print values to stdout even if the JAX function has been decorated with jit (or others like pmap and pjit).
To achieve this, we can use the jax.debug package from JAX, which offers tools to inspect values inside of a JIT-decorated JAX function.
However, f-strings are not suitable for the package’s jax.debug.print() function, because this JAX-specific print function delays formatting.
Instead, a regular string with {} placeholders should be passed as the format string, with the values to print supplied as separate arguments.
To be super clear: when using the jax.debug.print() function, the fmt argument (the first argument of this print function) must be a regular string, because the formatting operation is delayed until the compiled function actually runs and the values are known.
This is different from f-strings, where the formatting operation is done immediately, as soon as Python executes the line.
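The idea of delayed formatting can be mimicked in plain Python. The Pending class below is a hypothetical stand-in and not part of JAX; it only loosely imitates a value that is unknown while the code is being traced and becomes known later.

```python
class Pending:
    """Hypothetical stand-in for a value that is only known later."""

    def __init__(self):
        self.value = None

    def __str__(self):
        return "<pending>" if self.value is None else str(self.value)


x = Pending()

# Eager: the f-string is formatted right now, while the value is still unknown.
eager = f"The value of x is {x}"

# Delayed: keep the template untouched and format it once the value exists.
template = "The value of x is {}"
x.value = 3
delayed = template.format(x)

print(eager)    # The value of x is <pending>
print(delayed)  # The value of x is 3
```

Keeping the template and the value apart is what makes the late substitution possible; once an f-string has run, there is nothing left to fill in.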
As an example, consider the following f-string:
x = 3
print(f"The value of x is {x}")
In this case, the value of x is formatted and substituted into the string immediately, before the print() function even receives it.
This is useful when we need to format strings at runtime.
However, when using jax.debug.print(), we cannot use an f-string, because the formatting operation is delayed by JAX. For example, this code does not work as intended, since the value of x is baked into the string before JAX ever sees it:
import jax
x = 3
# Don't do this: the f-string is already formatted when JAX receives it
jax.debug.print(f"The value of x is {x}")
Instead, to make the above example work, we pass a regular string with placeholders as the first argument and hand the values to jax.debug.print() as separate arguments. Note that calling the string format() method ourselves, or concatenating with str(x), formats the string eagerly, just like an f-string, so neither avoids the problem. Here are some examples that are compatible with JAX’s print function:
# Using a positional placeholder
x = 3
jax.debug.print("The value of x is {}", x)
# Using a named placeholder
x = 3
jax.debug.print("The value of x is {x}", x=x)
In both examples, the formatting operation is delayed: jax.debug.print() receives the template and the value separately and substitutes the value of x later, when the function actually runs.
All in all, this means that when printing with JAX, we should keep the text and the data separate.
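Putting this together, here is a minimal sketch of printing from inside a JIT-compiled function. The function scaled_sum and its arguments are made up for illustration; only jax.jit, jax.numpy and jax.debug.print come from JAX itself.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    total = jnp.sum(x) * scale
    # The template and the values are passed separately, so JAX can fill in
    # the runtime values when the compiled function executes.
    jax.debug.print("sum = {}, scale = {s}", jnp.sum(x), s=scale)
    jax.debug.print("total = {total}", total=total)
    return total


result = scaled_sum(jnp.arange(4.0), 2.0)
```

Inside scaled_sum, x and scale are traced values, so an f-string at this point could only show their abstract representation. The placeholder version prints the actual numbers each time the function runs.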
Conclusion and further resources:
There are many more debugging tools we can use in JAX, including but not limited to jax.debug.breakpoint(), which pauses execution of a compiled JAX function so that we can inspect values, and the checkify.checkify transformation together with the checkify.check function (both found in jax.experimental.checkify), which allow adding functional error checks to JAX code that are translated at compile time into fast and efficient code.
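As a rough sketch of how such a functional error check might look, assuming the checkify module from jax.experimental: the function safe_log and its error message are invented for this example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # Functional assertion: a failure is recorded as an error value
    # instead of raising from inside the compiled code.
    checkify.check(x > 0, "safe_log expects a positive input")
    return jnp.log(x)


# checkify turns the function into one that returns (error, output).
checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, out_ok = checked_log(jnp.float32(1.0))
err_bad, out_bad = checked_log(jnp.float32(-1.0))

print(err_ok.get())   # None when no check failed
print(err_bad.get())  # an error message when the check failed
# err_bad.throw() would raise the recorded error as a Python exception.
```

Because the error comes back as a regular return value, the check survives jit compilation, and we decide outside the compiled function whether to raise it.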
If you want to learn more about debugging JAX code and how to use the JAX debug package, then please see the following resources at https://jax.readthedocs.io/en/latest/debugging/.
That’s it. I hope these examples helped you to (a) understand the difference between Python’s regular string and f-string and (b) see how we can use this difference to our advantage when debugging our Machine Learning and Deep Learning models coded in JAX. Thank you and until next time!