PyTorch steps up to help handling larger than life models

PyTorch, a Python library for tensor computation and neural networks, is now available in version 1.4, which brings experimental distributed model parallel training and Java bindings, among other things.

PyTorch is a project written in a combination of Python, C++, and CUDA, developed mainly at Facebook’s AI research lab. It has shared a repository with the deep learning framework Caffe2 since 2018 and is one of the main competitors to Google’s TensorFlow.

In PyTorch 1.4, distributed model parallel training has been added to accommodate the growing scale and complexity of modern models. In the case of Facebook’s RoBERTa method, the number of parameters can run into the billions, which is more than a single machine can comfortably hold.

The remote procedure call framework used to implement distributed training facilitates such scenarios by helping to split large models across multiple machines and eliminating the need to copy data back and forth during training. Instead, it allows developers to reference remote objects and run functions remotely, which also makes it useful for inference. A tutorial to help users get into the nitty gritty of this still experimental feature can be found on the PyTorch website.
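For a rough idea of what this looks like in practice, here is a minimal sketch using the experimental torch.distributed.rpc API; the worker names, port, and toy add_on_remote function are illustrative choices for this example, not anything prescribed by the framework.

import os
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

def add_on_remote(a, b):
    # Executed on the callee; only the result travels back.
    return a + b

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Synchronous call: blocks until worker1 returns the sum.
        result = rpc.rpc_sync("worker1", add_on_remote,
                              args=(torch.ones(2), torch.ones(2)))
        # rpc.remote() returns an RRef, a handle to a value that stays
        # on worker1 until to_here() pulls it over.
        rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        print(result, rref.to_here())
    rpc.shutdown()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)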

But it’s not only in the training phase that a model’s size can become problematic, if you think about the variety of devices that use machine learning nowadays. This is where pruning, another new addition to PyTorch, comes into play. Version 1.4 adds a variety of techniques to reduce the size and complexity of a model, letting it run in more resource-constrained environments and become faster without an unacceptable drop in accuracy (at least if you strike the right balance).
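As a quick sketch, the new torch.nn.utils.prune module can be applied to a single layer along these lines; the layer dimensions and the 30 percent pruning amount are arbitrary values chosen for illustration.

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(in_features=64, out_features=10)

# Adds a weight mask and reparametrizes "weight" on the fly,
# zeroing out the 30% of entries with the smallest L1 magnitude.
prune.l1_unstructured(layer, name="weight", amount=0.3)
print(float((layer.weight == 0).sum()) / layer.weight.nelement())  # roughly 0.3

# Bake the mask into the tensor and drop the pruning bookkeeping.
prune.remove(layer, "weight")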

The project’s just-in-time compiler now has a clone_instance function for ScriptModule and allows module containers to be used as iterables, including in list comprehensions. Dictionaries in the JIT now preserve insertion order, while calls to submodules are finally stored in the traced graph.
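A small example of the container-as-iterable change: scripting a module that loops over an nn.ModuleList in its forward pass might look roughly like this (the layer sizes are made up for the sketch).

import torch
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, x):
        # Iterating over the container directly is accepted by the scripting compiler.
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

scripted = torch.jit.script(Stack())
print(scripted(torch.randn(1, 8)).shape)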

Apart from that, the PyTorch team has upgraded mobile support by adding a way to include only the operators used by a given model during the build process, which can noticeably reduce the library’s footprint on a device. Java developers on Linux who want to look into deep learning can now invoke TorchScript models from their programs, since a first version of Java support has been added as well. The update also lets users define multiple learning rate schedulers in torch.optim and use them in combination without them overwriting each other, as sketched below.
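To illustrate the scheduler change, here is a minimal sketch of two learning rate schedulers driving the same optimizer; the dummy model, the choice of ExponentialLR and StepLR, and their parameters are assumptions made purely for the example.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

exp = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(10):
    # ... actual training steps would go here ...
    optimizer.step()
    exp.step()   # both schedulers now compose instead of
    step.step()  # the second one overwriting the first

print(optimizer.param_groups[0]["lr"])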

More details can be found in the release notes. Since PyTorch 1.4 comes with a number of backwards-incompatible changes, they are well worth studying to make sure old programs keep running as intended.