Sequential Tinkering - Part 1
Objax offers a significantly enhanced version of the ubiquitous Sequential layer. In this first part of a two-part series, I explore some of its basic yet powerful features.
What is a sequential layer anyway?
It is very common in machine learning to run a sequence of operations where one layer’s output is used as input for the next. That’s pretty much what a sequential layer does: it runs a sequence of modules one after the other, feeding each module’s output into the next.
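Without the framework details, that behavior is a simple loop. The following minimal pure-Python sketch assumes layers are plain callables. run_sequential and the toy layers are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Sequence


def run_sequential(layers: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Hypothetical helper: feed x through each layer in order.

    Each layer's output becomes the next layer's input.
    """
    for layer in layers:
        x = layer(x)
    return x


# Toy "layers" on plain numbers, standing in for real modules.
double = lambda v: 2 * v
add_one = lambda v: v + 1

print(run_sequential([double, add_one, double], 3))  # (3 * 2 + 1) * 2 = 14
```

An empty sequence simply returns its input unchanged, which is the natural identity for composition.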
Before going into Objax’s Sequential module, let’s take a look at some other frameworks.
Keras
Keras’ Sequential class implements a standard interface. With the add method, one can add layers to a sequence, then build the model and execute it.
```python
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8, input_shape=(16,)))
model.add(tf.keras.layers.Dense(4))
model.build((None, 16))
```
PyTorch
PyTorch’s Sequential module offers a similar experience, with one main difference: it accepts either a list of modules (passed as *args) or an OrderedDict for custom naming of layers.
```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 20, 5),
                      nn.ReLU(),
                      nn.Conv2d(20, 64, 5),
                      nn.ReLU())
```
Objax Sequential module
Objax’s Sequential module is basically a Python list, so many operations on it feel natural from a Python perspective. It can do everything the Sequential implementations shown above can do, and its Pythonic design makes a few more things easy.
Objax supports custom naming by design, so, unlike PyTorch, Sequential needs no OrderedDict support.
Let’s go through the following features in depth:
- Mixing modules and functions.
- Benefits of list syntax.
- Argument passing to modules in a sequence.
Basics
In its most basic form, a Sequential layer is a list of modules and functions. One difference from PyTorch is that functions don’t need to be wrapped as modules to be used in an Objax Sequential. They can be used as is, which avoids duplicating the API. To change a function’s default arguments, you can use the standard Python functools.partial API. Here’s an example using leaky_relu:
```python
from functools import partial
import objax

model = objax.nn.Sequential([
    objax.nn.Conv2D(1, 20, 5),
    objax.functional.relu,
    objax.nn.Conv2D(20, 64, 5),
    partial(objax.functional.leaky_relu, negative_slope=0.1)])
print(model.vars())
# (Sequential)[0](Conv2D).b 20 (20, 1, 1)
# (Sequential)[0](Conv2D).w 500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b 64 (64, 1, 1)
# (Sequential)[2](Conv2D).w 32000 (5, 5, 20, 64)
# +Total(4) 32584
```
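Here is a minimal pure-Python sketch of what partial does to a default argument. The scalar leaky_relu below is a hypothetical stand-in written for illustration. It is not objax.functional.leaky_relu.

```python
from functools import partial


def leaky_relu(x: float, negative_slope: float = 0.01) -> float:
    """Hypothetical scalar stand-in for a leaky ReLU activation."""
    return x if x >= 0 else x * negative_slope


# Freeze a new default for negative_slope; the original function is untouched.
leaky_0_1 = partial(leaky_relu, negative_slope=0.1)

print(leaky_relu(-2.0))                     # -0.02 (original default 0.01)
print(leaky_0_1(-2.0))                      # -0.2  (default replaced by partial)
print(leaky_0_1(-2.0, negative_slope=0.5))  # -1.0  (call-time keyword wins)
```

The last call shows a property of functools.partial itself: keywords supplied at call time override the ones stored in the partial.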
Benefits of the list syntax
Since Sequential is a list, one can easily inspect a layer or replace it using standard Python syntax:
```python
# Inspect the bias of the first convolution
print(model[0].b.value)
# Update the value of the first convolution bias
model[0].b.assign(model[0].b.value + 1)
# Replace the second convolution
model[2] = objax.nn.Conv2D(20, 32, 3)
print(model.vars())
# (Sequential)[0](Conv2D).b 20 (20, 1, 1)
# (Sequential)[0](Conv2D).w 500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b 32 (32, 1, 1)
# (Sequential)[2](Conv2D).w 5760 (3, 3, 20, 32)
# +Total(4) 6312
```
Just like a regular list, Sequential also supports append, extend, and insert:
```python
model.append(objax.functional.relu)
# Insert before the last element (the relu just appended)
model.insert(-1, objax.nn.Conv2D(32, 8, 3))
print(len(model))
# 6
print(model.vars())
# (Sequential)[0](Conv2D).b 20 (20, 1, 1)
# (Sequential)[0](Conv2D).w 500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b 32 (32, 1, 1)
# (Sequential)[2](Conv2D).w 5760 (3, 3, 20, 32)
# (Sequential)[4](Conv2D).b 8 (8, 1, 1)
# (Sequential)[4](Conv2D).w 2304 (3, 3, 32, 8)
# +Total(6) 8624
```
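To make the "Sequential is a list" idea concrete, here is a hypothetical pure-Python sketch. It is not Objax’s actual class, only a list subclass with a __call__ method, so it inherits indexing, assignment, append, and insert from list for free.

```python
from typing import Any


class ListSequential(list):
    """Hypothetical sketch: a callable list of layers (not Objax's class)."""

    def __call__(self, x: Any) -> Any:
        # Run the layers in list order, feeding each output to the next.
        for layer in self:
            x = layer(x)
        return x


model = ListSequential([lambda v: v + 1, lambda v: v * 10])
print(model(2))    # (2 + 1) * 10 = 30

model.append(lambda v: v - 5)  # list method, inherited
model.insert(0, abs)           # list method, inherited
print(len(model))  # 4
print(model(-2))   # abs -> 2, +1 -> 3, *10 -> 30, -5 -> 25

model[1] = lambda v: v + 2     # replace a layer by index
print(model(-2))   # 2 -> 4 -> 40 -> 35
```

One caveat of this simplified sketch: slicing it returns a plain list, because list.__getitem__ is not overridden.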
Argument passing to modules in Sequential
Sometimes you may want to customize a function call, or call a module whose __call__ method takes extra arguments. Sequential handles this automatically: keyword arguments passed to the Sequential are forwarded to the layers that accept them.
```python
import jax.numpy as jn
import objax

model = objax.nn.Sequential([
    objax.nn.Linear(1, 20),
    lambda x, negative_slope=0.01: jn.maximum(x, x * negative_slope),
    objax.nn.Linear(20, 1),
    objax.functional.leaky_relu  # Built-in leaky_relu also uses "negative_slope"
])
# Create some mock input
x = jn.arange(1, 6).reshape((5, 1))
print(model(x).flatten())
# [0.08871414 0.17742828 0.26614243 0.35485655 0.4435708 ]
# Now set negative_slope; it reaches both the lambda and leaky_relu
print(model(x, negative_slope=0.2).flatten())
# [0.16334192 0.32668385 0.49002576 0.6533677 0.8167097 ]
```
The same concepts apply to module arguments. For example, we can make a modified Linear module that takes an extra argument.
```python
import jax.numpy as jn
import objax
from objax.typing import JaxArray


class MyModule(objax.nn.Linear):
    # We add an extra float argument called offset.
    # We also add kwargs support for a scale argument.
    def __call__(self, x: JaxArray, offset: float, **kwargs) -> JaxArray:
        y = jn.dot(x, self.w.value)
        if self.b is not None:
            y += self.b.value
        if 'scale' in kwargs:
            y *= kwargs['scale']
        return y + offset


model = objax.nn.Sequential([
    MyModule(1, 20), objax.functional.leaky_relu,
    MyModule(20, 1), objax.functional.leaky_relu])
# Create some mock input
x = jn.arange(1, 6).reshape((5, 1))
print(model(x, offset=0).flatten())
# [0.08871414 0.17742828 0.26614243 0.35485655 0.4435708 ]
print(model(x, offset=1).flatten())
# [-0.00090197 0.24507368 0.31848168 0.47458148 0.6344541 ]
print(model(x, offset=1, scale=1).flatten())
# [-0.00090197 0.24507368 0.31848168 0.47458148 0.6344541 ]
# Note: as expected, scale=1 doesn't change the return value
print(model(x, offset=1, scale=0).flatten())
# [1. 1. 1. 1. 1.]
# Note: scale=0 multiplies the layer output by 0, so only the offset is left.
print(model(x, offset=1, negative_slope=0.2).flatten())
# [-0.01803935 0.27238262 0.42335594 0.64115095 0.8619946 ]
# And, like before, you can pass arguments to leaky_relu as well.
```
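The argument routing described above can be approximated in pure Python with inspect.signature. This sketch rests on an assumption consistent with the examples: each layer receives only the keyword arguments its signature declares, or all of them if it accepts **kwargs. KwargSequential, accepted_kwargs, Shift, and the scalar leaky_relu are hypothetical names, and this is not Objax’s implementation.

```python
import inspect
from typing import Any, Callable, Dict


def accepted_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only the kwargs fn's signature can receive.

    A callable declaring **kwargs receives every keyword argument.
    """
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)
    names = {p.name for p in params
             if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                           inspect.Parameter.KEYWORD_ONLY)}
    return {k: v for k, v in kwargs.items() if k in names}


class KwargSequential(list):
    """Hypothetical sketch: a callable list that routes keyword args per layer."""

    def __call__(self, x: Any, **kwargs: Any) -> Any:
        for layer in self:
            x = layer(x, **accepted_kwargs(layer, kwargs))
        return x


def leaky_relu(x: float, negative_slope: float = 0.01) -> float:
    """Hypothetical scalar stand-in for a leaky ReLU."""
    return x if x >= 0 else x * negative_slope


class Shift:
    """Hypothetical layer: required `offset`, optional `scale` via **kwargs."""

    def __call__(self, x: float, offset: float, **kwargs: Any) -> float:
        return x * kwargs.get('scale', 1.0) + offset


model = KwargSequential([Shift(), leaky_relu])
print(model(-3.0, offset=1.0))                      # -2.0 * 0.01 -> -0.02
print(model(-3.0, offset=1.0, negative_slope=0.5))  # -2.0 * 0.5  -> -1.0
print(model(-3.0, offset=1.0, scale=0.0))           # 0.0 + 1.0   -> 1.0
```

Forwarding only the keywords a layer declares lets each layer ignore arguments meant for others. A layer with **kwargs opts in to everything, the way MyModule receives scale in the example above.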
Conclusion
Because Objax’s Sequential is designed as a list, you can manipulate it with ordinary Python syntax. It also transparently forwards arguments to the functions and modules that use them, as the examples above show.
In the next part of this series, we’ll look at Sequential’s advanced features. These are useful for complex neural networks, for representation learning (for example, easily creating embeddings from a classifier), and more.