How Deep Learning & GBMs Can Change the Tabular Data Landscape
Zian (Andy) Wang
Deep Learning, with its broad range of applications spanning fields such as language processing, image recognition, and audio modeling, has become a highly visible and widely recognized aspect of artificial intelligence. The recent buzz surrounding models like ChatGPT and the impressive capabilities of diffusion models have only drawn more attention to Deep Learning for image and language tasks. However, despite the significant advancements in these areas, modern Deep Learning methods remain comparatively under-researched and under-applied for one of the most crucial and fundamental types of data: tabular data.
Many bold research papers, online courses, and blog posts operate under the assumption that "classical" tree-based models and gradient boosting machines (GBMs) outperform Deep Learning methods on tabular data. This is not to suggest that there is no solid evidence supporting tree-based approaches over Deep Learning. But far too frequently, this nuanced research is misinterpreted as a general rule, and skeptics of Deep Learning often fall into the same flawed reasoning as those advancing it: taking insights gained within a typically well-defined set of restrictions and applying them outside of those limitations.
📊 Background: Tabular Data has Changed
Looking back, it is easy to assume that Deep Learning approaches are ill-suited to tabular data, but times have changed. Tabular datasets account for the majority of day-to-day data analysis and processing, and it is not unreasonable to suggest that tabular data is the most varied of all data forms. Advances in data collection techniques have greatly enlarged the range of domains these datasets cover.
Despite this, many well-publicized research surveys and more informal investigations showing the success of tree-based models over Deep Learning methods rely on standard benchmark datasets: the Forest Cover dataset, the Higgs Boson dataset, the California Housing dataset, the Wine Quality dataset, and so on. These datasets, even when evaluated by the dozen, are undoubtedly limited. They may once have been representative, but they are a far cry from representing the entire tabular data domain. Of course, we must acknowledge that an evaluative survey built on diverse, poorly behaved datasets is much more difficult to conduct than one built on homogeneous benchmark datasets. Yet those who tout the findings of such studies as a broad verdict on the capabilities of neural networks for tabular data overlook the sheer breadth of tabular domains in which machine learning models are applied.
The choice of "baseline" models in studies that argue against Deep Learning for tabular data is another prevalent issue. Tree-based methods have been the focus of research and implementation for decades: decision trees existed as early as the 1960s, and GBMs have been in use for over a decade. The tree-based "baselines" used for benchmarking are, in fact, highly optimized state-of-the-art techniques. The neural networks pitted against them, on the other hand, tend to be simple MLPs and standard architecture variants, which fail to reflect the sophistication and potential of the Deep Learning paradigm.
🎤 The State of Affairs in Tabular Data
With the rise of biological, biochemical, and medical advances, much of tabular data modeling has shifted from predicting house prices and classifying plants to discovering the complex relationships between biological and chemical processes. These datasets often contain hundreds of features interacting with each other in innumerable ways that a simple decision tree cannot model. Additionally, social media platforms have taken the world by storm, increasing the importance of accurate content recommendation algorithms, another form of tabular data modeling. The tabular data space is far more diverse than some might assume, and it has changed considerably since the 2000s and early 2010s. Today's datasets are increasingly sophisticated, capturing highly complex phenomena.
It is important to note that tree-based models are not necessarily "bad choices" for modern tabular data modeling. Tree-based models and GBMs still have various advantages and are useful for quick baselining and testing. They are arguably faster to train than neural network ensembles and are far more interpretable. Furthermore, the deterministic, rule-based nature of tree-based models closely aligns with how we reason through problems in the real world. For example, a tree-based model may be a good fit for deciding whether today's conditions are suitable for playing golf outdoors: each of its decision nodes asks a simple, human-readable question about the day's conditions, and the answers lead to a prediction.
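To make the golf example concrete, here is a minimal sketch of such a tree as plain Python data. The features (`outlook`, `humidity`, `windy`), the 75% threshold, and the tree layout are hypothetical and hand-written for illustration, not learned from any dataset; the point is only that every prediction can be traced as a short chain of readable questions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple, Union


@dataclass
class Split:
    """An internal node: a readable question plus the test that answers it."""
    description: str
    test: Callable[[Dict], bool]   # True -> follow the "yes" branch
    yes: "Tree"
    no: "Tree"


Tree = Union[str, Split]  # a leaf is simply the predicted label


# Hypothetical, hand-written tree; features and thresholds are illustrative only.
GOLF_TREE: Tree = Split(
    "Is it raining?", lambda day: day["outlook"] == "rain",
    yes=Split("Is it windy?", lambda day: day["windy"], yes="don't play", no="play"),
    no=Split("Is humidity above 75%?", lambda day: day["humidity"] > 75,
             yes="don't play", no="play"),
)


def predict_with_path(tree: Tree, day: Dict) -> Tuple[str, List[str]]:
    """Walk from the root to a leaf, recording every question and its answer."""
    path: List[str] = []
    node = tree
    while isinstance(node, Split):
        answer = bool(node.test(day))
        path.append(f"{node.description} {'yes' if answer else 'no'}")
        node = node.yes if answer else node.no
    return node, path


label, path = predict_with_path(GOLF_TREE, {"outlook": "sunny", "humidity": 60, "windy": True})
print(label, path)  # play ['Is it raining? no', 'Is humidity above 75%? no']
```

The recorded path is exactly the kind of explanation that makes trees attractive: it can be read aloud as a justification for the prediction.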
But, as we know, modern tabular datasets aren't about deciding whether today's a good day to go outside or not. This is where tree-based models fall short.
🐢 Limitations of Traditional Approaches
Simple supervised learning is not the only problem in modern tabular data modeling. Datasets are often noisy, and we need ways to either denoise the data or make models robust to noise. Tabular data is volatile and often shifts over time, so we need models that can adapt and generalize to different scenarios. Even with advanced technology, resources are sometimes insufficient, or the nature of the problem prevents a significant amount of data from being collected; in those cases, we need ways to generate realistic synthetic data. To our knowledge, tree-based models cannot accomplish these tasks or have great difficulty doing so. They are prone to considerable overfitting because of their highly specific, largely deterministic splits, and they aren't great candidates for large and complicated datasets.
We can find many more examples if only we look. Many tabular datasets contain text and image attributes, such as an online product reviews dataset that includes a textual review and an image of the product, along with the usual information represented in a tabular fashion. When predicting house listing prices, it may be extremely helpful to include interior images along with the associated standard tabular information such as the square footage, number of bathrooms, and so on.
Alternatively, consider stock price data that captures time series and company information in tabular form. What if, to forecast stock prices, we also include the top ten financial headlines alongside the tabular and time-series data? Tree-based models, to our knowledge, cannot effectively address any of these multimodal problems; they are extremely limited in their versatility and adaptability to different types of data.
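One common neural pattern for multimodal problems like this is late fusion: each modality gets its own encoder, the encodings are concatenated, and a shared head makes the prediction. The NumPy sketch below assumes a tiny 1-D convolution for the price series, a hashed bag-of-words as a stand-in for a real text encoder, and random untrained weights; `FusionModel` and all input values are hypothetical and only illustrate the data flow.

```python
import zlib

import numpy as np


class FusionModel:
    """Hypothetical late-fusion network: one encoder per modality, then a shared MLP head.

    Weights are random and untrained; this only illustrates how modalities are combined.
    """

    def __init__(self, n_tabular, text_dim=32, n_filters=4, kernel_size=3, hidden=16, seed=0):
        rng = np.random.default_rng(seed)
        self.text_dim = text_dim
        self.kernels = rng.normal(scale=0.1, size=(n_filters, kernel_size))
        fused_dim = n_tabular + n_filters + text_dim
        self.W1 = rng.normal(scale=0.1, size=(fused_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(scale=0.1, size=(hidden, 1))
        self.b2 = np.zeros(1)

    def encode_series(self, prices):
        # Log returns, then a valid 1-D convolution per filter, ReLU, and global average pooling.
        returns = np.diff(np.log(prices))
        feats = [np.maximum(np.convolve(returns, k, mode="valid"), 0.0).mean() for k in self.kernels]
        return np.array(feats)

    def encode_text(self, headlines):
        # Hashed bag-of-words averaged over headlines (a stand-in for a language model).
        vec = np.zeros(self.text_dim)
        for text in headlines:
            for token in text.lower().split():
                vec[zlib.crc32(token.encode()) % self.text_dim] += 1.0
        return vec / max(len(headlines), 1)

    def forward(self, tabular, prices, headlines):
        fused = np.concatenate([tabular, self.encode_series(prices), self.encode_text(headlines)])
        hidden = np.maximum(fused @ self.W1 + self.b1, 0.0)
        return (hidden @ self.W2 + self.b2).item()  # e.g. a predicted next-day return


model = FusionModel(n_tabular=3)
tabular = np.array([0.25, 1.3, -0.4])  # made-up standardized company features
prices = np.array([100.0, 101.5, 100.8, 102.2, 103.0, 102.4])
headlines = ["Chipmaker beats earnings estimates", "Central bank holds rates steady"]
print(model.forward(tabular, prices, headlines))
```

Because every encoder is differentiable, the whole model could be trained end to end on a single loss, which is the main practical advantage of this design over hand-flattening each modality into fixed columns.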
Of course, neural networks, and Deep Learning models in general, aren't perfect; there are several legitimate concerns about them:
Interpretability: As mentioned above, tree-based models follow a decision process similar to human logic, which makes them naturally interpretable. By analyzing the branches and nodes of a decision tree (or a GBM), hidden patterns and correlations within the dataset can surface; this knowledge can then be used to explain why and how a model makes a particular prediction and to discover possible bias in the model or data. Neural networks, by contrast, are well known as black-box models whose decision-making process is difficult or impossible to reason about.
Lack of data: Deep Learning models are known to be data-hungry. State-of-the-art architectures typically require vast amounts of training data to achieve their promised performance. Tree-based models, on the other hand, can handle relatively small datasets with less overfitting than neural networks typically show.
Inability to preprocess data: neural networks can't natively preprocess data in a way that reflects the practical meaning of the features; categorical columns, for instance, must be encoded numerically before training. Take CatBoost, whose name (from "Category" and "Boosting") reflects its preprocessing scheme: it accepts categorical features without manual encoding because a categorical-processing pipeline is built into the model.
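CatBoost's built-in category handling is based on ordered target statistics: each categorical value is replaced by a smoothed average of the target computed only from rows that come earlier in a random permutation, so a row's own label never leaks into its encoding. The plain-Python sketch below is a simplified single-permutation version of that idea; the function name, the prior, and the smoothing weight `a` are illustrative assumptions, not CatBoost's API or defaults.

```python
import random
from collections import defaultdict
from typing import List, Sequence


def ordered_target_encode(categories: Sequence[str], targets: Sequence[float],
                          prior: float, a: float = 1.0, seed: int = 0) -> List[float]:
    """Simplified ordered target statistics using one random permutation.

    Row i is encoded as (sum of targets of *earlier* rows with the same category + a * prior)
    divided by (number of those earlier rows + a), so no row sees its own label.
    """
    order = list(range(len(categories)))
    random.Random(seed).shuffle(order)
    sums = defaultdict(float)
    counts = defaultdict(int)
    encoded = [0.0] * len(categories)
    for i in order:
        c = categories[i]
        encoded[i] = (sums[c] + a * prior) / (counts[c] + a)
        sums[c] += targets[i]  # update the running statistics only *after* encoding row i
        counts[c] += 1
    return encoded


colors = ["red", "blue", "red", "green", "red", "blue"]
y = [1, 0, 1, 0, 0, 1]
prior = sum(y) / len(y)  # 0.5
print(ordered_target_encode(colors, y, prior))
```

The first occurrence of each category (in permutation order) falls back to the prior, and later occurrences move toward that category's observed target rate. CatBoost itself goes further, for example by using several permutations, but the leakage-avoiding ordering is the core idea.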
🧠 Deep Learning for Tabular Data
But again, this isn't the 2010s anymore. The objections to Deep Learning models mentioned above are actively being researched and addressed. Recent work by Caglar Aytekin shows that neural networks can be represented as equivalent decision trees, a major step forward for neural network interpretability. Activation maximization, although an older technique, is still extremely useful: by optimizing an input to maximally activate a given neuron or output, it reveals what kind of feature pattern the network is "looking for" when predicting a particular label.
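The core observation behind the network-as-tree view can be sketched for a single-hidden-layer ReLU network with random weights (a simplifying assumption; as we understand it, the published construction also covers deeper networks). Each hidden unit acts as a decision node testing whether its pre-activation is positive, the resulting on/off pattern identifies a leaf, and inside each leaf the whole network collapses to one linear model.

```python
import numpy as np

rng = np.random.default_rng(42)
n_in, n_hidden, n_out = 3, 5, 1
W1, b1 = rng.normal(size=(n_hidden, n_in)), rng.normal(size=n_hidden)
W2, b2 = rng.normal(size=(n_out, n_hidden)), rng.normal(size=n_out)


def relu_net(x):
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2


def as_decision_path(x):
    """Read the network as a tree: hidden unit i is a node testing w_i . x + b_i > 0."""
    pattern = (W1 @ x + b1) > 0  # the branch taken at every node; together it names the leaf
    # Within this leaf, inactive units contribute nothing, so the net is a single linear map.
    W_eff = W2 @ (pattern[:, None] * W1)
    b_eff = W2 @ (pattern * b1) + b2
    return pattern, W_eff, b_eff


x = rng.normal(size=n_in)
pattern, W_eff, b_eff = as_decision_path(x)
print("leaf (activation pattern):", pattern.astype(int))
print("network output:", relu_net(x), "leaf linear model:", W_eff @ x + b_eff)
```

The two printed outputs agree for any input, which is what lets each prediction be explained by a path of threshold tests plus a leaf-specific linear model, much like a tree.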
Tabular GANs have been explored over the past few years, with TGAN appearing as one of the first GANs dedicated to generating synthetic tabular data and outperforming Variational Autoencoders (VAEs), a general-purpose data generation method. It was followed by CTGAN, a conditional tabular GAN that is now one of the most prominent and widely used models. CTGAN and its variants provide another extremely useful feature for tabular data modeling: anonymization. To protect user and customer privacy, features such as names, addresses, and other information about one's identity or status cannot be used directly to train machine learning models. CTGANs can synthesize entirely made-up samples that closely follow the original dataset's distributions, patterns, and correlations. Another relatively new study proposed GReaT and distilled-GReaT, which use large language models (LLMs) for conditional tabular data generation. Instead of conditioning generation on a few booleans and numbers, GReaT can synthesize tabular data from text descriptions, streamlining the data generation process for less technically inclined users.
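A key ingredient that lets CTGAN cope with imbalanced discrete columns is "training-by-sampling": at each step it picks a discrete column, picks one of that column's categories with probability weighted by the log of its frequency, and conditions the generator on a one-hot vector encoding that choice. The NumPy sketch below shows only this sampling step, not a working GAN; the `log(count + 1)` smoothing, the function name `sample_condition`, and the toy table are assumptions for illustration.

```python
import numpy as np


def sample_condition(discrete_columns, rng):
    """Sketch of CTGAN-style training-by-sampling.

    discrete_columns: dict mapping column name -> list of category values (one per row).
    Returns the chosen column, the chosen category, and the one-hot conditional vector
    that would be fed to the generator alongside its noise vector.
    """
    names = list(discrete_columns)
    col = names[rng.integers(len(names))]             # 1. pick a discrete column uniformly
    values, counts = np.unique(discrete_columns[col], return_counts=True)
    probs = np.log(counts + 1.0)                      # 2. log-frequency weighting favors rare categories
    probs /= probs.sum()
    k = rng.choice(len(values), p=probs)
    # 3. Conditional vector: one block per discrete column, all zeros except the chosen category.
    blocks = []
    for name in names:
        block = np.zeros(len(np.unique(discrete_columns[name])))
        if name == col:
            block[k] = 1.0
        blocks.append(block)
    return col, values[k], np.concatenate(blocks)


rng = np.random.default_rng(0)
table = {
    "payment": ["card"] * 95 + ["cash"] * 5,
    "region": ["north"] * 50 + ["south"] * 30 + ["east"] * 20,
}
col, value, cond = sample_condition(table, rng)
print(col, value, cond)
```

With raw frequencies, "cash" would be requested only 5% of the time the payment column is chosen; log weighting raises that to roughly 28%, so the generator gets far more practice on the minority category.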
Incorporating the best of both worlds, many models, such as GrowNet, Deep Neural Decision Trees, Neural Oblivious Decision Ensembles (NODE), and XBNet, apply tree-based modeling concepts to Deep Learning. Researchers have also been borrowing the success of attention-based Transformers for tabular data with models such as TabTransformer, TabNet, and SAINT. Notably, these models also incorporate "preprocessing" approaches: TabNet implements instance-wise feature selection through sequential attention, while TabTransformer and SAINT focus on generating contextual embeddings. Finally, all three models recommend semi-supervised pre-training before learning task-specific predictions.
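To illustrate how such hybrids make trees trainable by gradient descent, here is a simplified soft oblivious decision tree in the spirit of NODE. "Oblivious" means every level shares one split, so a depth-d tree has 2**d leaves; "soft" means each split returns a probability instead of a hard left/right decision. This NumPy sketch substitutes softmax and sigmoid for NODE's entmax and entmoid, shows only the forward pass, and uses a hypothetical class name; in practice an autodiff framework would supply the gradients.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class SoftObliviousTree:
    """Simplified differentiable oblivious tree (NODE-style sketch, forward pass only)."""

    def __init__(self, n_features, depth=3, temperature=1.0, seed=0):
        rng = np.random.default_rng(seed)
        self.depth = depth
        self.temperature = temperature
        self.feature_logits = rng.normal(size=(depth, n_features))  # soft feature choice per level
        self.thresholds = rng.normal(size=depth)                    # one threshold per level
        self.leaf_values = rng.normal(size=2 ** depth)              # a learnable response per leaf

    def leaf_probabilities(self, X):
        # Softmax over features picks a (soft) feature for each level: shape (depth, n_features).
        logits = self.feature_logits
        selectors = np.exp(logits - logits.max(axis=1, keepdims=True))
        selectors /= selectors.sum(axis=1, keepdims=True)
        chosen = X @ selectors.T                                        # (n_samples, depth)
        right = sigmoid((chosen - self.thresholds) / self.temperature)  # P(go right) at each level
        probs = np.ones((X.shape[0], 1))
        for level in range(self.depth):
            p = right[:, level:level + 1]
            # Every current leaf splits into a left and a right child.
            probs = np.concatenate([probs * (1.0 - p), probs * p], axis=1)
        return probs                                                    # rows sum to 1

    def predict(self, X):
        # Prediction = expected leaf value under the soft routing probabilities.
        return self.leaf_probabilities(X) @ self.leaf_values


tree = SoftObliviousTree(n_features=4, depth=3)
X = np.random.default_rng(1).normal(size=(5, 4))
print(tree.leaf_probabilities(X).round(3))
print(tree.predict(X))
```

Lowering `temperature` pushes the routing toward hard, classical tree splits, while higher values keep it smooth enough for gradient-based training.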
There haven't just been "promises" from papers and researchers; we are starting to see successful real-world applications of Deep Learning to tabular data. In the Kaggle Mechanisms of Action (MoA) Prediction competition, competitors classified drugs based on their biological activity. In the early stages of the competition, many attempted to use GBMs and other tree-based models, with little success. The dataset was heavily imbalanced and multi-label, meaning that each sample can simultaneously correspond to multiple binary labels. Standard GBM implementations predict one target at a time, so a separate model must be trained for every label. This increases training costs and forgoes the opportunity to exploit correlations between labels.
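The contrast with one-model-per-label can be sketched in NumPy: a single network with one shared hidden layer feeds one sigmoid output per label, and one mean binary cross-entropy loss trains all labels at once, so whatever the hidden layer learns for one label is available to the others. The synthetic data, layer sizes, learning rate, and step count are arbitrary assumptions for illustration, not the MoA setup or any competitor's solution.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_labels, n_hidden = 256, 20, 5, 32

# Made-up multi-label data: all labels depend on a shared latent signal, the kind of
# cross-label structure a shared hidden layer can exploit.
X = rng.normal(size=(n_samples, n_features))
latent = X @ rng.normal(size=(n_features, 3))
Y = (latent @ rng.normal(size=(3, n_labels)) > 0).astype(float)

W1 = rng.normal(scale=0.1, size=(n_features, n_hidden)); b1 = np.zeros(n_hidden)
W2 = rng.normal(scale=0.1, size=(n_hidden, n_labels)); b2 = np.zeros(n_labels)


def forward(X):
    Z1 = X @ W1 + b1
    H = np.maximum(Z1, 0.0)                      # shared hidden representation
    P = 1.0 / (1.0 + np.exp(-(H @ W2 + b2)))     # one independent sigmoid per label
    return Z1, H, P


def bce(P, Y, eps=1e-7):
    # Mean binary cross-entropy over every (sample, label) pair.
    P = np.clip(P, eps, 1 - eps)
    return -np.mean(Y * np.log(P) + (1 - Y) * np.log(1 - P))


lr = 0.5
initial_loss = bce(forward(X)[2], Y)
for step in range(300):
    Z1, H, P = forward(X)
    dlogits = (P - Y) / Y.size                   # gradient of the mean BCE w.r.t. the logits
    dW2, db2 = H.T @ dlogits, dlogits.sum(axis=0)
    dZ1 = (dlogits @ W2.T) * (Z1 > 0)            # backpropagate through the ReLU
    dW1, db1 = X.T @ dZ1, dZ1.sum(axis=0)
    W1 -= lr * dW1; b1 -= lr * db1
    W2 -= lr * dW2; b2 -= lr * db2
final_loss = bce(forward(X)[2], Y)
print(f"mean BCE: {initial_loss:.3f} -> {final_loss:.3f}")
```

A tree-based workflow would instead fit five separate models here, one per column of `Y`, none of which can see what the others learned.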
Consequently, the top solutions of the competition were almost all Deep Learning approaches. The MoA competition also popularized many Deep Learning techniques for tabular data, including CTGAN, GrowNet, TabNet, and CNN-based pipelines. Following the MoA competition, Kaggle's Jane Street Market Prediction competition, in which participants predicted stock trends from anonymized features, provided another real-world example of Deep Learning for tabular data. The winning solution was a simple Multi-Layer Perceptron (MLP) combined with an autoencoder that generated denoised versions of the original features. Again, no tree-based or other "classical" ML models came close.
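The denoising-autoencoder idea can be sketched in NumPy, assuming a one-layer tanh encoder, a linear decoder, Gaussian input corruption, and made-up data: the autoencoder sees corrupted inputs but is trained to reconstruct the clean ones, and its code is then concatenated with the original features as input to the downstream MLP. This illustrates the general technique only; it is not the winning solution's actual architecture or hyperparameters.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_features, n_code = 512, 12, 6

# Made-up data: a 3-dimensional signal observed through 12 noisy features.
signal = rng.normal(size=(n_samples, 3))
X = signal @ rng.normal(size=(3, n_features)) + 0.3 * rng.normal(size=(n_samples, n_features))

# One-layer tanh encoder and linear decoder (sizes are arbitrary assumptions).
W_enc = rng.normal(scale=0.1, size=(n_features, n_code)); b_enc = np.zeros(n_code)
W_dec = rng.normal(scale=0.1, size=(n_code, n_features)); b_dec = np.zeros(n_features)


def encode(X_in):
    return np.tanh(X_in @ W_enc + b_enc)


def clean_reconstruction_mse():
    return np.mean((encode(X) @ W_dec + b_dec - X) ** 2)


lr, noise_std = 0.05, 0.5
initial_mse = clean_reconstruction_mse()
for step in range(1000):
    X_noisy = X + noise_std * rng.normal(size=X.shape)  # corrupt the input...
    H = encode(X_noisy)
    err = H @ W_dec + b_dec - X                           # ...but reconstruct the clean target
    d_out = 2.0 * err / err.size                          # gradient of the mean squared error
    dW_dec, db_dec = H.T @ d_out, d_out.sum(axis=0)
    dZ = (d_out @ W_dec.T) * (1.0 - H ** 2)               # backpropagate through tanh
    dW_enc, db_enc = X_noisy.T @ dZ, dZ.sum(axis=0)
    W_dec -= lr * dW_dec; b_dec -= lr * db_dec
    W_enc -= lr * dW_enc; b_enc -= lr * db_enc
final_mse = clean_reconstruction_mse()

# Input for the downstream MLP: original features plus the learned, noise-robust code.
mlp_input = np.concatenate([X, encode(X)], axis=1)
print(f"clean reconstruction MSE: {initial_mse:.3f} -> {final_mse:.3f}, MLP input {mlp_input.shape}")
```

Training on corrupted inputs forces the code to capture structure shared across features rather than per-feature noise, which is why it can serve as a denoised view of the data.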
🎓 Takeaways: One Approach Need not Reign Supreme (Is it Time to Graduate to Deep Learning?)
This article is far from a comprehensive treatment of Deep Learning for tabular data, and it only scratches the surface of its potential. There are many more aspects to explore, such as neural networks' ability to update in real time on live data, a framework known as online learning, or applying Convolutional Neural Networks and Recurrent Neural Networks to tabular data. Meta-optimization is another caveat for neural networks, since they have many tunable hyperparameters that can "make or break" the model. However, tree-based models suffer from the same hyperparameter-selection problem to some extent.
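Online learning can be sketched with the simplest possible network, a single logistic unit updated by one SGD step per incoming mini-batch and never retrained from scratch. The stream, the drift at step 200, the learning rate, and the class and method names are hypothetical assumptions (the `partial_fit` name merely echoes scikit-learn's convention for incremental learning); the same update loop applies to any network trained by SGD.

```python
import numpy as np


class OnlineLogisticRegression:
    """Minimal online learner: one SGD step per incoming mini-batch."""

    def __init__(self, n_features, lr=0.5):
        self.w = np.zeros(n_features)
        self.b = 0.0
        self.lr = lr

    def predict_proba(self, X):
        return 1.0 / (1.0 + np.exp(-(X @ self.w + self.b)))

    def partial_fit(self, X, y):
        grad = self.predict_proba(X) - y  # gradient of the log loss w.r.t. the logits
        self.w -= self.lr * X.T @ grad / len(y)
        self.b -= self.lr * grad.mean()


def accuracy(model, w_true, rng, n=1000):
    # Evaluate on fresh samples drawn from the current concept.
    X = rng.normal(size=(n, len(w_true)))
    y = (X @ w_true > 0).astype(float)
    return ((model.predict_proba(X) > 0.5) == y).mean()


rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0, 0.5])
model = OnlineLogisticRegression(n_features=3)
history = {}
for t in range(400):
    if t == 200:
        history["before_drift"] = accuracy(model, w_true, rng)
        w_true = -w_true  # simulated concept drift: the relationship flips
        history["just_after_drift"] = accuracy(model, w_true, rng)
    X = rng.normal(size=(32, 3))  # a new mini-batch arrives from the stream
    y = (X @ w_true > 0).astype(float)
    model.partial_fit(X, y)       # update in place, then discard the batch
history["after_adapting"] = accuracy(model, w_true, rng)
print({k: round(float(v), 2) for k, v in history.items()})
```

Accuracy collapses right after the drift and then recovers as new batches arrive, without ever revisiting old data.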
Finally, this is not to say that tree-based models are superior to Deep Learning models, or that Deep Learning models generally outperform tree-based models. Instead, it demonstrates the potential of Deep Learning models for predictive tasks involving tabular data. Some of the ideas presented above are based on the introduction to my book with Apress, "Modern Deep Learning for Tabular Data". In it, we cover a wide range of research, theory, and applications, exploring the whole ecosystem surrounding tabular data modeling with Deep Learning and building everything from the ground up. In a world that Deep Learning is taking by storm, the importance of tabular data modeling cannot be ignored, and with modern techniques, its potential is unbounded.