pytorch featured

The Good and Bad of PyTorch Machine Learning Library

While AI may not take over the world anytime soon—at least as far as we know—it has transformed how we live our lives and conduct business. The machine learning market, once valued at $3.871 billion in 2022, is expected to grow to $49.875 billion by 2032—a 32.8 percent compound annual growth rate (CAGR).

These advancements in AI technology would be impossible without machine learning (ML). One of the ML tools that make machine learning possible is PyTorch, a popular machine learning framework.

In this article, we will learn about PyTorch in detail, including how it works, its pros and cons, its applications, and how it compares against other machine learning libraries like TensorFlow and Keras.

What is PyTorch?

PyTorch is an open-source machine learning library for training deep neural networks (DNNs). It was created by the Meta AI research lab in 2016 and released in October of the same year. PyTorch is written in C++ and Python.

The framework has become an alternative to Torch, its predecessor, and grew in popularity because of its ease of use in building ML models and its dynamic computational graphs that allow models to be optimized at runtime. But more on that later.

While it may sound complicated, PyTorch essentially allows you to train machine learning models in a few lines of Python-like code. Just like how a regular toolbox helps carpenters build chairs, tables, etc, PyTorch gives developers the tools to create and fine-tune models for tasks like image recognition, natural language processing, and predictive analysis.

Now that we understand what PyTorch is let’s explore how it works.

How does PyTorch work?

Let’s break down the process of training a neural network with PyTorch.

How PyTorch works

How PyTorch works

Preparing the data. The foundation of any successful neural network lies in the quality and suitability of the input data. PyTorch provides a range of tools and functions to assist with data preprocessing tasks, ensuring your data is in the optimal format for training your model.

The data preparation process involves steps like normalization, data augmentation, and splitting the data into training and test sets.

Converting the data into tensors. Before we can feed data into our neural network, we need to convert it into a representation  that PyTorch can understand — tensors. Tensors are multidimensional arrays that serve as the fundamental data structures in PyTorch.

The conversion is typically done using the torch.tensor function.

Creating datasets and dataloaders. After converting the data into tensors, the next step is to create datasets and dataloaders.

Datasets store data samples and their corresponding labels, while dataloaders iterate over a dataset during training. With the help of a dataloader, you can, for instance, define the sample batch size, specify whether the data must be shuffled to make the model generalize better, etc.  

Defining the neural network. At the heart of any neural network is its architecture, which defines the structure and flow of information through the network and its layers.

In PyTorch, neural networks are defined as Python classes that inherit from the torch.nn.Module class. The class should include the network's layers in the __init__ method and the forward pass logic in the forward method. The forward method defines how the input data passes through these layers to produce the desired output. This method should contain the layers and activation functions that make up the neural network.

Loss function and optimizer. As we train our network, we need a way to quantify how well it's performing and adjust its weights and biases. This is where the loss function and optimizer come in.

The loss function measures how well the network's predictions match the actual labels of the training data. Common loss functions in PyTorch include torch.nn.CrossEntropyLoss for classification tasks and torch.nn.MSELoss for regression tasks.

The optimizer helps minimize the loss by adjusting and updating the network’s internal parameters (weights and biases). Popular PyTorch optimizers include algorithms like stochastic gradient descent (torch.optim.SGD), Adam (torch.optim.Adam), and RMSprop (torch.optim.RMSprop).

Training the model. The training, just like with any other tool, involves feeding the neural network batches of data, allowing it to make predictions, and then using the loss function and a selected optimization algorithm to adjust its internal parameters based on how well it performed. PyTorch will report on the loss results after every 1000 batches.

Evaluating the model. After training the model, it’s important to evaluate its performance on a separate test subset of data. This helps to understand how well the model performs on new, unseen data.

Passing the test data through the trained model and comparing its predictions to the correct answers helps measure how well the neural network has learned and helps calculate metrics like accuracy, precision, and recall.

Components and features of PyTorch

Tensors. Tensors are the fundamental data structures used in PyTorch. They are multidimensional arrays that can be used to store data and act as building blocks of PyTorch’s computational processes,helping developers manipulate a model’s inputs, outputs, and parameters. While tensors help represent complex multidimensional data, you can use them for data of any type and size.

Besides, tensors are the best fit for fast GPU-based computation. PyTorch supports parallel computing with platforms like NVIDIA’s CUDA.

What is a tensor in machine learning?

What is a tensor in machine learning?

As we mentioned, tensors in machine learning are normally used to represent complex data as multidimensional arrays, but with PyTorch as well as TensorFlow, you can express any kind of data as a tensor:

  • Scalars: zero-dimensional arrays that carry a single number. Example: torch.tensor(3.14) creates a 0D tensor (scalar) holding the value 3.14
  • Vectors: one-dimensional arrays that carry multiple scalars of the same type. Example: torch.tensor([1.0, 2.0, 3.0]) creates a 1D tensor (vector) with three elements
  • Matrices: two-dimensional arrays that carry vectors of the same type. Example: torch.tensor([[1, 2], [3, 4]]) creates a 2D tensor (matrix) with two rows and two columns
  • Multidimensional arrays: they extend beyond two dimensions. Example: torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) creates a 3D tensor

Modules. Modules are building blocks for creating neural networks. You can think of them as self-contained units that perform specific tasks. A module can contain other modules, parameters, and functions.

Combining multiple modules allows you to create neural network architectures for tasks like image recognition, natural language processing, and more.

Common PyTorch modules include:

  • torch module: PyTorch’s main module that contains all other modules
  • torch.nn module: provides pre-defined layers, loss functions, and activation functions for constructing neural networks
  • torch.autograd module: provides automatic differentiation for all operations on tensors
  • torch.optim module: provides neural network optimization algorithms, like SGD, Adam, and RMSProp. Pairing this with the Autograd module makes training models straightforward
  • torch.utils module: provides utilities and helper functions for tasks like data loading, model saving and loading, and performance profiling. These utilities streamline the ML development process
PyTorch modules

PyTorch modules

Parameters. Parameters are tensors that store the learnable weights and biases of a model. We can find parameters under the torch.nn module as torch.nn.Parameter.

Dynamic computational graph. First, let’s understand what computational graphs are. Computational graphs are a visual and mathematical way to represent complex mathematical expressions and operations. In a computational graph, nodes represent mathematical operations or functions, while edges represent the flow of data or inputs between the nodes. In machine learning frameworks, computational graphs describe the operations of a neural network, such as backpropagation during training.  

Example of an augmented computational graph by Preferred Networks

Example of an augmented computational graph by Preferred Networks. Source: PyTorch Blog. CC BY 3.0. No changes to the image were made.

Now, instead of using a static computational graph where you have to define the graph before running, PyTorch provides dynamic computational graphs that are defined and built on the go. This means that operations and nodes can be added or removed dynamically as the computation progresses, which improves flexibility and allows you to debug, build prototypes, and iterate on new models during the training process.

PyTorch’s dynamic graph nature is particularly useful for models with dynamic flow control, like recurrent neural networks (RNNs).

Datasets and Dataloaders. As we mentioned, datasets and dataloaders are primitives that PyTorch provides for working with large datasets.

The dataset class, available in torch.utils.data.Dataset, class represents a collection of data samples that need to be loaded.

The dataloader class, available in torch.utils.data.Dataloader, handles batching, shuffling, and parallel loading of the datasets. It helps reduce the time it takes to load datasets.

Pros of using PyTorch

Easy to use

PyTorch has a user-friendly syntax that closely resembles Python. This Pythonic nature makes it easier for beginner and experienced Python developers to learn PyTorch and use it to train neural networks.

PyTorch’s Pythonic nature also boosts rapid prototyping and helps developers test and iterate ideas quickly.

Compatible with Python libraries

PyTorch integrates seamlessly with Python libraries like NumPy, SciPy, and Pandas to simplify the process of manipulating, processing, and analyzing data.

Using the tools in Python’s rich ecosystem helps speed up development and decrease the cost of production.

Comprehensive documentation

PyTorch’s robust documentation covers basic and advanced details on how the framework works. The documentation includes installation steps, beginner tutorials on tensors, the various modules available, how to build neural networks, and more. It also provides a search bar for easy navigation and research.

Dynamic computation graphs

PyTorch’s support for dynamic computational graphs provides greater flexibility than other machine learning frameworks since the graphs can be optimized during runtime—the period during which the model is being trained or making predictions.

PyTorch’s dynamic nature also makes it easier to debug and understand how models behave during training since the computational graph gives you real-time feedback on the results of every operation. It also supports Python’s native pdb and ipdb debugging tools, which makes it easier to rectify issues. Developers can also use PyCharm—the Python IDE—for debugging.

Strong community and industry support

PyTorch has a large and active community that helps its users enjoy their shared knowledge and support. This means PyTorch users get access to several resources, tutorials, libraries, pre-trained models, and tools for visualizing and interpreting models.

Data from Stack Overflow’s 2023 developer survey shows that 8.41 percent of developers use TensorFlow, and 7.8 percent use PyTorch.

PyTorch’s community continually develops libraries like GPyTorch, BoTorch, and Allen NLP, which helps extend the PyTorch framework. Developers can also get assistance and answers to questions at the PyTrch forum or GitHub repository.

Strong support for GPU acceleration

PyTorch can leverage the parallel computing power of GPUs to accelerate the computationally intensive operations involved in training and running deep learning models.

It provides a torch.cuda module that supports integration with tools like CUDA (Compute Unified Device Architecture), NVIDIA’s parallel computing platform and API for GPUs.

This parallel processing ability, which allows you to train deep learning models faster, is especially useful for large-scale deep learning projects that require processing large amounts of data.

Good for research

PyTorch’s dynamic computational graph, Pythonic nature, and ease of use for prototyping models have made it a top choice in the research community.

Many large companies like Amazon, Tesla, Meta, and Open AI use PyTorch to power their machine learning and AI research initiatives. For example, switching to PyTorch helped Open AI decrease their iteration time on research ideas in modeling generative AI from weeks to days.

Between 2020 and 2024, 57 percent of research teams used PyTorch for their machine-learning research.

What PyTorch allows us to do is experiment very quickly. It's showing incredible promise. What we are seeing, using these new modeling techniques, we are able to take some of the problems and experiment and deploy them into production in a very short time.

Srinivas Narayanan, head of Facebook AI Applied Research. Source: Business Insider

Cons of using PyTorch

Lacks a visual interface

A major disadvantage of PyTorch is its lack of a built-in visual interface. Visualization tools are essential for monitoring and debugging neural networks, providing insights into the training process.

The lack of a visual interface means developers must use command-line tools, custom scripts, and third-party libraries to visualize their model's performance and training progress. This can increase the complexity for users, especially those new to the library or machine learning in general.

Harder to deploy models to mobile devices

The demand for deploying machine learning models on mobile devices has increased, and PyTorch has provided PyTorch Mobile, a runtime that enables developers to run PyTorch models on Android and iOS platforms.

However, PyTorch mobile is not as easy and comprehensive as TensorFlow Lite. It is still in beta, and it requires more manual implementation and configuration than TensorFlow Lite. This can be a barrier for developers looking for a seamless solution to deploy models to mobile applications.

Use cases and applications of PyTorch

PyTorch is widely used across various industries. Here are some common use cases and applications of PyTorch:

Computer vision: Computer vision enables computers to interpret and understand visual information. it helps computers analyze and extract data from images and videos.

PyTorch’s flexible and dynamic nature allows researchers to develop complex architectures like attention mechanisms and Generative Adversarial Networks (GANs). This makes it a top choice for building, training, and evaluating deep learning models for computer-vision-related tasks like image classification, object detection, and image generation (GANs).

Natural language processing (NLP): NLP deals with a computer’s ability to understand, interpret, and generate human language. PyTorch’s ability to process and model text data is a top choice for NLP tasks like text classification, language translation, and sentiment analysis.

NLP is widely used to power AI chatbots, personal assistants, and translation software like Google Translate.

Speech and audio processing: This involves using algorithms to analyze, understand, manipulate, and generate speech and audio data.

PyTorch is used in this domain because of its flexibility, efficient tensor operations, and the libraries it offers, like Torchaudio, for processing speech and audio data.

Key speech and audio processing tasks that can be powered by PTorch include speech recognition, speech synthesis, audio classification, and music generation.

Generative models: Generative AI models have taken the world by storm, with GANs and transformer-based models being two of the most widely used ones.

These models are responsible for tasks like image and text generation, style transfer, and data augmentation, and are responsible for the “magic” we see on AI image generation platforms like Midjourney and AI chatbots like ChatGPT.

Companies using PyTorch

PyTorch is being used to identify cancer cells, discover new drugs, build video games, make self-driving cars, and more. It is powering AI and ML research at academic institutions and Fortune 500 companies alike. Here are some businesses using PyTorch.

Genentech

Genentech uses PyTorch to develop new drugs by searching through millions of chemical structures, building models that predict how patients will react to treatments, and developing vaccines for different types of cancer.

Genentech also works with the PyTorch community and contributes to its development.

Uber

Uber uses PyTorch to build Pyro, a deep probabilistic programming language. They also use it to deploy hundreds of PyTorch models for applications like estimated time of arrival (ETA) predictions, demand forecasting, and menu transcriptions for Uber Eats.

Amazon

Amazon uses PyTorch to flag ads that don’t comply with the content guidelines of its advertising service. They use PyTorch to develop computer vision and natural language processing models that automatically flag potentially non-compliant ads.

Integrating PyTorch with their internal infrastructure allows Amazon to provide an optimal customer experience and create ML models that are fast enough to serve ads in milliseconds.

Meta

Meta uses PyTorch to train its language translation system that handles 6 billion translations a day and to power its Instagram’s recommendation system that suggests new content (Feeds, Stories, or Reels) that match your taste.

Tesla

Tesla uses PyTorch to train and deploy deep learning models for Autopilot—their self-driving technology—and to power features like lane-keeping assistance, object detection, and Smart Summon for their cars.

How Computer Vision Applications WorkPlayButton

How Computer Vision Applications Work

Airbnb

Airbnb uses PyTorch for their customer service department’s dialog assistant. The assistant is powered by a sequence-to-sequence model (built with PyTorch) that delivers smart replies to customers and helps improve customer interactions.

Comparing PyTorch with other machine learning libraries

While PyTorch is a top choice in the AI and machine learning domain, it's not the only option, as alternatives like TensorFlow and Keras exist. Let’s explore how PyTorch compares to its alternatives.

PyTorch vs TensorFlow

TensorFlow is a low-level open-source library for implementing machine learning models, training deep neural networks, and solving complex numerical problems. It was created by Google Brain Team and released in 2015.

Types of computational graphs. PyTorch uses dynamic computational graphs, while TensorFlow (versions lower than 2.0) uses static computational graphs.

PyTorch’s dynamic computational graph provides greater flexibility and ease of experimentation during model development. It allows developers to modify the graph on the go. On the other hand, TensorFlow’s static computational graphs offer better optimization opportunities for production-ready models.

Deployment. TensorFlow is the winner when it comes to deploying trained models to production. It offers TensorFlow Serving, an in-built model deployment tool that Google uses for its projects. Torchserve is available for deploying PyTorch models. Torchserve doesn’t offer as many features as TensorFlow. However, its latest version supports HuggingFace, AWS Cloud Formation, Nvidia Waveglow, and others.

Learning curve. PyTorch is easier to learn and takes a “Pythonic” approach. This means that it sticks closely to Python and uses core Python concepts like classes, structures, and conditional loops. TensorFlow, on the other hand, has a more complex syntax than PyTorch, which can make it difficult for beginners to understand.

Research friendliness: While both frameworks are used extensively in research, the flexibility offered by PyTorch, combined with its Pythonic nature, makes it a favorite for many researchers. They can easily tweak models, try new architectures, and experiment without much boilerplate.

Popularity. As the older library, TensorFlow has a larger community and more learning resources like tutorials, courses, and books. As of this writing, it also has more GitHub stars (183 thousand) than PyTorch (79.1 thousand).

PyTorch vs Keras

Technically speaking, Keras is not a library. It is actually a high-level deep learning API for developing neural networks. It was developed by Google in 2015, is written in Python, and can run on top of Microsoft CNTK, TensorFlow, and Theano.

Level of abstraction. PyTorch provides a low-level and flexible approach, allowing developers to have more control over the model architecture and training process. Keras, on the other hand, prioritizes simplicity and abstraction. It provides a user-friendly interface that abstracts away many of the low-level details, making it easier to prototype and build models quickly. This high-level abstraction makes Keras more accessible to beginners and researchers who may not have extensive experience with low-level tensor operations.

Computation graphs. Similar to TensorFlow, Keras has a static computational graph. When using the functional API or the Sequential model in Keras, the computational graph is defined before execution, and any modifications require redefining the graph.

While Keras takes the static approach by default, it also supports eager execution, which allows for dynamic computation graphs similar to PyTorch.

Popularity. PyTorch’s dynamic computation graphs, low-level control, and tight integration with Python make it the preferred solution for academic research and industry applications.

Keras’s high-level abstraction and ease of use have made it a popular choice for rapid prototyping and experimentation and for those new to deep learning. It currently has 61.1 thousand GitHub stars.

Getting started: how to install PyTorch

You’ve learned about the basics of PyTorch and how it works— great! Now, let’s explore how to get started with it.

See Also

A good starting point is the official “get started” guide, which covers how to install PyTorch and its dependencies.

There’s also the PyTorch documentation. Explore it for comprehensive information on how PyTorch works and its functionalities.

Want to learn more about PyTorch basics, best practices, and more? Then, explore the Tutorials section. It includes various examples of PyTorch in action, a cheat sheet, and links to example tutorials on GitHub.

You can also join the PyTorch developer community to stay updated on the latest developments, contribute to PyTorch, interact with fellow PyTorch developers, and get assistance.

There’s also the ecosystem page, which is a directory of the libraries and tools in the PyTorch ecosystem.

This post is a part of our “The Good and the Bad” series. For more information about the pros and cons of the most popular technologies, see the other articles from the series:

Comments