Sequential Tinkering - Part 1


Objax offers a significantly enhanced version of the ubiquitous Sequential layer. In this first part of a two-part series, I explore some of its basic, yet powerful, features.

What is a sequential layer anyway?

It is very common in machine learning to run a sequence of operations where one layer’s output is used as the input to the next. And that’s pretty much what a sequential layer does: it’s a sequence of modules executed one after the other, each module’s output fed into the next.
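
Conceptually, that is nothing more than a loop. The sketch below is a plain-Python illustration of the idea, not any particular framework’s implementation:

def run_sequence(modules, x):
    # Feed each module's output into the next one.
    for module in modules:
        x = module(x)
    return x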

Before going into Objax’s Sequential module, let’s take a look at some other frameworks.

Keras

Keras’ Sequential class implements a standard interface: with the add method you add layers to the sequence, then build the model and execute it.

import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8, input_shape=(16,)))
model.add(tf.keras.layers.Dense(4))
model.build((None, 16))
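
Once built, the model can be called like a function. Here is a quick check with a mock batch (the shapes are just for illustration):

# Mock batch of 2 samples with 16 features each; the last Dense layer has 4 units.
output = model(tf.zeros((2, 16)))
print(output.shape)  # (2, 4)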

PyTorch

PyTorch’s Sequential module offers a similar experience with one main difference: layers can be passed either as *args* or as an ordered dictionary, the latter allowing custom naming of layers (an example of the OrderedDict form follows the code below).

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 20, 5),
                      nn.ReLU(),
                      nn.Conv2d(20, 64, 5),
                      nn.ReLU())
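
The OrderedDict form mentioned above looks like this, giving each layer a custom name:

from collections import OrderedDict

named_model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())]))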

Objax Sequential module

Objax’s Sequential module is basically a Python list. As such, quite a few things become natural from a Python perspective. It can do what the Sequential implementations shown above can do, and thanks to its Pythonic design it does a few more things easily. Objax supports general custom naming by design, so, unlike PyTorch, there’s no need for OrderedDict support in Sequential.

Let’s go in depth through the following features:

  • Mixing modules and functions.

  • Benefits of list syntax.

  • Argument passing to modules in a sequence.

Basics

In its most basic form, a Sequential layer is a list of modules and functions. One difference from PyTorch is that functions in Objax don’t need to be provided as modules to be used in Sequential: they can be used as is, resulting in less API duplication.

You can use the standard Python functools.partial API to modify the default arguments of a function. Here’s an example using leaky_relu:

from functools import partial
import objax

model = objax.nn.Sequential([
    objax.nn.Conv2D(1, 20, 5),
    objax.functional.relu,
    objax.nn.Conv2D(20, 64, 5),
    partial(objax.functional.leaky_relu, negative_slope=0.1)])

print(model.vars())
# (Sequential)[0](Conv2D).b       20 (20, 1, 1)
# (Sequential)[0](Conv2D).w      500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b       64 (64, 1, 1)
# (Sequential)[2](Conv2D).w    32000 (5, 5, 20, 64)
# +Total(4)                    32584
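
As a quick sanity check, the model can be applied directly to a mock batch. The input shape below is hypothetical; Objax convolutions expect NCHW layout:

import jax.numpy as jn

# Hypothetical batch of two single-channel 28x28 images (NCHW).
x = jn.zeros((2, 1, 28, 28))
y = model(x)
print(y.shape)  # (2, 64, height, width); the spatial size depends on Conv2D's padding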

Benefits of the list syntax

Since Sequential is a list, one can easily inspect a layer or replace it using standard Python syntax:

# Inspect the bias of the first convolution
print(model[0].b.value)

# Update the value of the first convolution bias
model[0].b.assign(model[0].b.value + 1)

# Replace the second convolution
model[2] = objax.nn.Conv2D(20, 32, 3)

print(model.vars())
# (Sequential)[0](Conv2D).b       20 (20, 1, 1)
# (Sequential)[0](Conv2D).w      500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b       32 (32, 1, 1)
# (Sequential)[2](Conv2D).w     5760 (3, 3, 20, 32)
# +Total(4)                     6312

Just like with a regular list, one can also append to, extend, or insert into a Sequential:

model.append(objax.functional.relu)
model.insert(-1, objax.nn.Conv2D(32, 8, 3))
print(len(model))
# 6
print(model.vars())
# (Sequential)[0](Conv2D).b       20 (20, 1, 1)
# (Sequential)[0](Conv2D).w      500 (5, 5, 1, 20)
# (Sequential)[2](Conv2D).b       32 (32, 1, 1)
# (Sequential)[2](Conv2D).w     5760 (3, 3, 20, 32)
# (Sequential)[4](Conv2D).b        8 (8, 1, 1)
# (Sequential)[4](Conv2D).w     2304 (3, 3, 32, 8)
# +Total(6)                     8624
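
And since it really is a list, iterating over the layers works with plain Python as well. A small sketch:

# List each entry of the sequence with its positional index.
for i, layer in enumerate(model):
    name = getattr(layer, '__name__', type(layer).__name__)
    print(i, name)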

Argument passing to modules in Sequential

Sometimes one may want to customize a function call, or call a module whose __call__ method takes extra arguments. Sequential forwards such arguments automatically to the callables that accept them.

import jax.numpy as jn
import objax

model = objax.nn.Sequential([
    objax.nn.Linear(1, 20),
    lambda x, negative_slope=0.01: jn.maximum(x, x * negative_slope),
    objax.nn.Linear(20, 1),
    objax.functional.leaky_relu  # Built-in leaky_relu also uses "negative_slope"
])

# Create some mock input
x = jn.arange(1, 6).reshape((5, 1))

print(model(x).flatten())
# [0.08871414 0.17742828 0.26614243 0.35485655 0.4435708 ]

# Now let's set the negative_slope argument for leaky_relu
print(model(x, negative_slope=0.2).flatten())
# [0.16334192 0.32668385 0.49002576 0.6533677  0.8167097 ]
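
Under the hood, this kind of selective forwarding can be achieved by inspecting each callable’s signature and passing only the keyword arguments it declares. The sketch below illustrates the idea in plain Python; it is a simplification, not Objax’s actual implementation:

import inspect

def forward_kwargs(fn, x, **kwargs):
    # Keep only the keyword arguments that fn actually declares.
    accepted = {k: v for k, v in kwargs.items()
                if k in inspect.signature(fn).parameters}
    return fn(x, **accepted)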

The same concepts apply to module arguments. For example, we can make a modified Linear module that takes an extra argument.

import jax.numpy as jn
import objax
from objax.typing import JaxArray

class MyModule(objax.nn.Linear):
    # We add an extra float argument called offset.
    # We also add kwargs support for a scale argument.
    def __call__(self, x: JaxArray, offset: float, **kwargs) -> JaxArray:
        y = jn.dot(x, self.w.value)
        if self.b is not None:
            y += self.b.value
        if 'scale' in kwargs:
            y *= kwargs['scale']
        return y + offset

model = objax.nn.Sequential([
    MyModule(1, 20), objax.functional.leaky_relu,
    MyModule(20, 1), objax.functional.leaky_relu])

# Create some mock input
x = jn.arange(1, 6).reshape((5, 1))

print(model(x, offset=0).flatten())
# [0.08871414 0.17742828 0.26614243 0.35485655 0.4435708 ]

print(model(x, offset=1).flatten())
# [-0.00090197  0.24507368  0.31848168  0.47458148  0.6344541 ]

print(model(x, offset=1, scale=1).flatten())
# [-0.00090197  0.24507368  0.31848168  0.47458148  0.6344541 ]
# Note: as expected scale=1 doesn't change the return value

print(model(x, offset=1, scale=0).flatten())
# [1. 1. 1. 1. 1.]
# Note: scale=0 multiplies the layer output by 0, thus only the offset is left.

print(model(x, offset=1, negative_slope=0.2).flatten())
# [-0.01803935  0.27238262  0.42335594  0.64115095  0.8619946 ]
# And like before, you can send arguments to leaky_relu as well.

Conclusion

Objax’s design of Sequential as a list makes manipulating it follow familiar Python behavior. It also transparently supports passing arguments to the functions and modules that use them (see the examples above).

In the next part of this series, we’ll look into Sequential’s advanced features that are useful for complex neural networks, such as easily creating embeddings from a classifier for representation learning, and more.