Deep Learning for Tabular Data
In this post, I will discuss why deep learning is less effective for tabular data compared to other data modalities, followed by some reasons why deep learning for tabular data is still a worthwhile consideration.
Almost anyone who has worked with data before in any capacity has come across tabular data. Tabular data is so ubiquitous that in many situations, when someone refers to “data” they are talking specifically about tabular data. To clear any confusion, I am using the term “tabular data” to refer to the kind of data typically stored in a CSV file: two-dimensional data where rows are data samples and columns are (usually named) data features. It’s the kind of data that “traditional” machine learning algorithms like decision trees, linear/logistic regressions, and SVMs are designed to work on.
Meanwhile, deep learning, i.e. machine learning with neural networks, has completely revolutionized many domains, most notably computer vision and natural language processing (NLP). These fields rely so heavily on deep learning that building a state-of-the-art model without a neural network is practically unheard of. Yet for tabular data, neural networks are usually considered inferior to state-of-the-art traditional machine learning models such as XGBoost and random forests. Why is deep learning the undisputed choice in so many fields and yet not for tabular data?
This is a question that I personally had a hard time answering when I started to get into machine learning. On the one hand, people in statistics love tree-based models, and these models, in particular XGBoost and other gradient boosted decision tree algorithms, are always winning Kaggle competitions involving tabular data. On the other hand, there’s such a buzz about deep learning among computer scientists, who have demonstrated its remarkable capabilities in computer vision, NLP, and elsewhere. This raises the question: why is tabular data different from these other kinds of data, and why does deep learning not work as well for it? In this post, I will first explore some answers to this question, and then I will explore some reasons why studying deep learning for tabular data is still worthwhile.
Reasons for Hesitation
The first step towards understanding deep learning for tabular data is understanding why deep learning is not the unanimous model choice for tabular data as it is for other data modalities. I will provide three reasons why I believe this is the case.
1) Dataset Size
Let’s think for a little about the size of tabular datasets versus datasets where deep learning thrives. Tabular datasets come in all forms, but most often they are long and skinny, that is, many data samples and far fewer data features. There are certainly tabular datasets with more features than data samples (e.g. biological datasets), but these are the minority. Even when there are many features, it is often the case that only a subset of those features is actually important or non-redundant. Now compare this to an image dataset, for example CIFAR-10. One image is a \(32 \times 32\) grid of pixels for each of the 3 color channels, i.e. \(32 \cdot 32 \cdot 3 = 3072\) features per image. Further, many of these features are important for classifying each image. With 60000 images and 3072 features per image, many of which are critical, this is a lot of data, especially compared to a comparable tabular dataset with 60000 rows (an image is worth 1000 rows?). And CIFAR-10 is small compared to image datasets like ImageNet. The takeaway here is that domains where deep learning thrives often contain datasets much larger than your average tabular dataset, in length and especially in width. For these massive datasets, neural networks are extremely scalable and capable of absorbing tremendous amounts of information, which is necessary for learning the extremely complex relationships they contain. With smaller tabular datasets, however, neural networks often do not yield noticeably better performance since there is not as much signal to absorb, and since they are significantly more expensive and slower to train, they are often not worth using.
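To make the size gap concrete, here is a quick back-of-the-envelope comparison; the tabular numbers are hypothetical, chosen only to illustrate a typical “long and skinny” table:

```python
# Rough comparison of total dataset "volume" (samples x features).
# The tabular figures below are made up purely for illustration.
cifar_samples, cifar_features = 60_000, 32 * 32 * 3      # 3072 features per image
table_samples, table_features = 60_000, 40               # a typical long, skinny table

print(f"CIFAR-10 values: {cifar_samples * cifar_features:,}")   # 184,320,000
print(f"Tabular values:  {table_samples * table_features:,}")   #   2,400,000
```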
2) Tabular Data Heterogeneity
Let’s revisit CIFAR-10: there are 3072 pixel features, but each of these pixel values is an integer in the range \([0, 255]\). In tabular datasets, on the other hand, each column can be numerical, categorical (ordinal or nominal), or even text. This means that for neural networks, we have to find a way to encode everything numerically, since neural networks can only process numerical data. The typical approach in a neural network is to learn “embeddings” for the categorical data during model training, but learning good categorical embeddings is not always possible without large amounts of data. Meanwhile, many tree-based implementations (e.g. LightGBM and CatBoost) can handle categorical features natively, so we don’t have to worry about encodings or embeddings. This extra preprocessing, especially for small to medium sized tabular datasets, can hold back the performance of deep learning compared to tree-based methods.
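As a rough illustration of what these learned embeddings look like in practice, here is a minimal sketch in PyTorch; the layer sizes and column setup are assumptions for the example, not a recommended architecture:

```python
import torch
import torch.nn as nn

class TabularNet(nn.Module):
    """Minimal sketch: one learned embedding table per categorical column,
    concatenated with the numeric columns and fed to a small MLP."""
    def __init__(self, cat_cardinalities, n_numeric, emb_dim=8, hidden=64):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(card, emb_dim) for card in cat_cardinalities]
        )
        in_dim = emb_dim * len(cat_cardinalities) + n_numeric
        self.mlp = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x_cat, x_num):
        # x_cat: (batch, n_cat) integer-coded categories; x_num: (batch, n_numeric) floats
        embedded = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)]
        return self.mlp(torch.cat(embedded + [x_num], dim=1))

# Example: two categorical columns with 10 and 4 levels, plus 5 numeric columns.
model = TabularNet(cat_cardinalities=[10, 4], n_numeric=5)
out = model(torch.randint(0, 4, (8, 2)), torch.randn(8, 5))
```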
Not only is there heterogeneity among the features in a tabular dataset, but there is also heterogeneity between tabular datasets themselves. Two tabular datasets almost never have the same features, and even if they have a feature in common, that feature is likely to have a different interpretation in each dataset. If you are used to tabular data, this is entirely normal, but in other data modalities, the data is actually much more similar across datasets. Two image datasets might differ in the type of images, the resolution of images, etc., but at the end of the day they are all still images composed of pixels which lie in the range \([0, 255]\). Two text datasets might contain very different kinds of text structure, but the text is still likely drawn from the same vocabulary of words. Similarity between datasets is relevant here for two reasons. First, it more easily allows for transfer learning, i.e. pretraining models on one dataset/task and then “transferring” that knowledge to a different downstream dataset/task. In other modalities where datasets are more similar, transfer learning between different datasets is reasonable and is a key to the success of many state-of-the-art deep learning systems. With tabular data, however, transfer learning between two datasets does not really make sense since the two datasets will have very different features. Second, the lack of similarity between tabular datasets makes it harder to incorporate inductive biases into models. Deep learning really hit its stride in computer vision when convolutional neural networks (CNNs) were designed, and similarly in NLP with recurrent neural networks (RNNs). One benefit of neural networks is that they are very flexible in design, since the only design restriction, in theory, is to use matrix/tensor operations that are differentiable. Therefore, neural architectures can be designed to incorporate inductive bias about the domain at hand. This is clearly evident in CNNs, for example, because they incorporate a spatial inductive bias about image data. With tabular data, it is much harder to design a neural architecture which takes advantage of an inductive bias since every tabular dataset is very different.
3) Interpretability
Lastly, the questions we ask of tabular datasets are often not the same as the questions asked in deep learning. In deep learning, we usually care about a prediction task of some kind, and the only goal is a model with the best predictive accuracy. With tabular data, however, we may be interested in not only predictive accuracy but also model interpretability. That is, we often care about how much each of our features impacts the model, since each feature usually represents something meaningful based on that column’s name. In other data modalities, this is less often the case; for example, we are usually not interested in the role of an individual pixel in a computer vision task. Neural networks are not very interpretable, so if we do care about interpretability, deep learning is simply not the right tool, and with tabular data we care about interpretability much more often.
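To illustrate the kind of feature-level insight that tree ensembles give essentially for free, here is a small sketch using scikit-learn’s built-in diabetes dataset; impurity-based importances have well-known caveats, but they show why practitioners reach for these models when interpretability matters:

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

X, y = load_diabetes(return_X_y=True, as_frame=True)
model = RandomForestRegressor(n_estimators=200, random_state=0).fit(X, y)

# Per-feature importances come built in with tree ensembles.
for name, score in sorted(zip(X.columns, model.feature_importances_),
                          key=lambda pair: -pair[1]):
    print(f"{name:>5}: {score:.3f}")
```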
Reasons for Optimism
The above three sections point towards deep learning not being as effective for tabular data, and this is largely true. Yet there are some reasons why exploring deep learning for tabular data is still worthwhile. After all, deep learning is so effective for other types of data that it is worth figuring out whether there are scenarios in which it is a worthwhile tool for tabular data.
1) More Data and Better Hardware
The point still stands that the average tabular dataset is smaller than the average computer vision dataset. However, data continues to grow, and there are plenty of absolutely massive tabular datasets, in which case the scalability of deep learning may provide superior results compared to tree-based algorithms. The problem remains that neural networks are slow and expensive to train, but hardware for training neural networks continues to improve dramatically, whereas the same cannot be said for training tree-based models. When training a tree-based model, the best hardware to use is many CPU cores, since the tree ensembling can be parallelized. With neural networks, the best hardware is GPUs and TPUs, which are massively parallel and optimized for the matrix/tensor operations common in neural networks. Even with the best GPUs, neural networks may still be slower than tree-based models, but GPUs continue to get better, so it is possible that this might change in the future. Another advantage of deep learning in this massive-dataset regime is that the entire dataset does not need to be loaded into memory, since neural networks usually optimize one mini-batch at a time. All this is to say that if you have a massive tabular dataset and powerful hardware, it is definitely worth considering deep learning.
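As a sketch of the “never load the whole table” point: a network can be optimized one chunk of a CSV at a time. The file name, column names, and feature count below are placeholders, not a real dataset:

```python
import pandas as pd
import torch
import torch.nn as nn

# Assumes a wide CSV ("huge_table.csv") with 100 numeric feature columns and a
# "target" column -- these names and sizes are placeholders for illustration.
model = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for chunk in pd.read_csv("huge_table.csv", chunksize=4096):   # stream, don't load it all
    x = torch.tensor(chunk.drop(columns="target").values, dtype=torch.float32)
    y = torch.tensor(chunk["target"].values, dtype=torch.float32).unsqueeze(1)
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
```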
2) Transfer Learning
I briefly mentioned transfer learning when discussing why tabular dataset heterogeneity makes deep learning less successful, but I want to discuss it in further detail here since I think it is a really important concept. Simply put, transfer learning is when a model is trained on one dataset and then “transferred” to another dataset/task for additional training and/or to solve a downstream task. A great example of this is word embeddings in NLP like GloVe: they are learned from massive datasets of raw text and are then applied as the first layer for many NLP tasks. This kind of model “pretraining” and later “fine-tuning” is really powerful for a number of reasons. First, a model trained once on a massive dataset, which is very computationally expensive, can then be transferred to many downstream tasks. For each downstream task, the bulk of the training has already been done, so fine-tuning is a much easier computational task than training from scratch. Second, pretraining can take advantage of unlabeled data through a “self-supervised” or “pretext” task: a supervised learning task artificially constructed from unlabeled data, which allows training with supervised techniques. If you want to know more about transfer learning, this article and the sources therein are a good place to start, and similarly this post for self-supervised learning.
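The fine-tuning step itself is simple to express. Below is a minimal, hypothetical sketch: the “pretrained” encoder here is just a stand-in module (in practice it would be loaded from a checkpoint), and only the new task head is trained:

```python
import torch
import torch.nn as nn

# Stand-in for an encoder pretrained elsewhere; in practice you would load
# weights from a checkpoint rather than build it fresh like this.
pretrained_encoder = nn.Sequential(nn.Linear(20, 128), nn.ReLU())
for p in pretrained_encoder.parameters():
    p.requires_grad = False            # freeze the transferred knowledge

head = nn.Sequential(nn.Linear(128, 32), nn.ReLU(), nn.Linear(32, 1))
model = nn.Sequential(pretrained_encoder, head)

# Only the head's parameters are optimized, so fine-tuning is cheap.
optimizer = torch.optim.Adam(head.parameters(), lr=1e-4)
```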
As mentioned before, the fact that tabular datasets are all very different from each other makes transfer learning between two tabular datasets unreasonable. However, transfer learning within the same tabular dataset still makes sense. This can be useful in a semi-supervised context where only a small subset of data points are labeled, since pretraining can be done on the unlabeled subset of the dataset. A critical point, though, is that transfer learning is a concept that only really makes sense for neural networks. This is because neural networks have hidden layers, and these hidden layers in a sense contain latent features about the training data. Latent features learned on one image dataset, for example, are probably useful for a different image dataset, because at the end of the day both datasets are of images. With tree-based models, or any other traditional machine learning method, there really isn’t anything comparable. If a model doesn’t construct this type of latent feature, then there isn’t really anything that you can transfer to a new dataset and/or task. Thus, the takeaway here is that transfer learning can still be useful for tabular datasets in a semi-supervised context, but we can only take advantage of it if we use deep learning.
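One concrete way to pretrain on the unlabeled rows is a masked-reconstruction pretext task: hide some feature values and train an encoder to fill them back in. The sketch below uses random tensors in place of real data, and the masking rate and layer sizes are arbitrary assumptions:

```python
import torch
import torch.nn as nn

n_features, n_unlabeled = 30, 10_000
x_unlabeled = torch.randn(n_unlabeled, n_features)    # stand-in for real unlabeled rows

encoder = nn.Sequential(nn.Linear(n_features, 64), nn.ReLU())
decoder = nn.Linear(64, n_features)
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)

for _ in range(100):
    mask = (torch.rand_like(x_unlabeled) < 0.15).float()        # hide ~15% of entries
    reconstruction = decoder(encoder(x_unlabeled * (1 - mask)))
    loss = ((reconstruction - x_unlabeled) ** 2 * mask).mean()  # score only masked entries
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# `encoder` now holds latent features learned without labels and can be
# fine-tuned on the small labeled subset.
```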
3) Multimodal Applications
A growing area of research over the past few decades is multimodal machine learning. Multimodal here means that the data comes from multiple modalities, and the goal of multimodal ML is to incorporate all available modalities into a single predictive model. The available modalities can be wide-ranging, but some of the most common are video, audio, image, text, and tabular. A fairly straightforward baseline approach would be to employ a separate model for each modality and then combine the models using some sort of ensembling approach. This baseline can easily be implemented using traditional machine learning models, but to go beyond it with more advanced multimodal learning approaches, we need more flexibility in how modalities are combined. This is again where deep learning provides something that traditional ML approaches cannot. The flexibility of neural architectures gives us the ability to combine modalities any way that we see fit.
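To make “combine modalities any way we see fit” concrete, here is a minimal late-fusion sketch: each modality gets its own encoder, the resulting feature vectors are concatenated, and a shared classifier sits on top. The encoders and sizes are toy stand-ins, not a recommended design:

```python
import torch
import torch.nn as nn

class LateFusionNet(nn.Module):
    """Sketch: encode image and tabular inputs separately, then fuse by concatenation."""
    def __init__(self, n_tabular, n_classes):
        super().__init__()
        self.image_encoder = nn.Sequential(               # toy CNN stand-in
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),         # -> (batch, 16)
        )
        self.tabular_encoder = nn.Sequential(nn.Linear(n_tabular, 32), nn.ReLU())
        self.classifier = nn.Linear(16 + 32, n_classes)

    def forward(self, image, tabular):
        fused = torch.cat([self.image_encoder(image), self.tabular_encoder(tabular)], dim=1)
        return self.classifier(fused)

model = LateFusionNet(n_tabular=20, n_classes=2)
logits = model(torch.randn(8, 3, 32, 32), torch.randn(8, 20))
```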
Conclusion
At this point, I have presented some reasons why deep learning is not as effective for tabular data compared to other domains, and also some reasons why deep learning still has merit for tabular data. Hopefully, this post has helped explain why you often don’t see deep learning used in traditional machine learning spheres, but also why deep learning is an important player in the tabular machine learning space. In a future post, I will talk about some of the state-of-the-art deep learning models for tabular data to give some insight into how one might use deep learning for tabular data.