Note: This is the second part of a series where we take a deeper dive into the question of data drift detection. If you haven't yet, check out the first part where we discussed data drift in the context of tabular data!
Introduction
Data drift detection is a key component of a machine learning monitoring system. So far, we’ve discussed what data drift can look like in the context of tabular data, as well as some approaches to measuring drift. To recap, let’s revisit a simple example of data drift in a single feature:
In this case, the distribution of age in the training dataset differs from its distribution in the production environment. Over time, the performance of a model using age as an input feature can decay in response to this change in the model's deployment environment. There are a variety of metrics for measuring the difference between these two distributions, but how do we measure drift without structured features? Systems trained on unstructured data, like text or images, face the same risks in production, yet detecting drift in these settings is more subtle, as we cannot apply common divergence metrics directly to the raw data. In this post, we'll walk through a general framework for data drift detection with unstructured data and highlight two example use cases: NLP and computer vision.
Our first example is a computer vision use case, where the goal is to classify images based on the objects they depict. For this setting, we used the STL-10 dataset from Stanford, which provides 96×96 color images spanning ten classes, including airplane, bird, dog, and truck.
Our second use case is in NLP, where we used a News Headline dataset containing headlines labeled with topics such as crime, entertainment, world news, and comedy. Here, the objective is to classify each headline into its correct category.
Algorithm/Approach
As with measuring multivariate drift in tabular data, the core motivation of the approach is to model the density, or distribution, of the reference dataset.
Overview
There are several different approaches for finding anomalies in unstructured data. Any given approach requires three main components to determine whether unseen data is anomalous:

1. A vector representation that converts each raw datapoint (image or text) into a meaningful embedding
2. A density model fit to the embeddings of the reference dataset
3. A scoring method that turns the density model's output into an actionable anomaly score

In the sections below, we discuss techniques for each of these three components and highlight example results on NLP and computer vision datasets.
Vector Representation
We must convert our image or text data into a meaningful vector representation in order to understand the underlying distribution of the reference dataset. These vector representations are a form of feature extraction that captures a useful summary of our unstructured data. Transfer learning is one approach for creating these representations: we extract embeddings for each image or text sequence from a large pre-trained model. These large-scale models are generally trained on millions of datapoints and use state-of-the-art architectures (CNNs for image data, Transformers for text data) that can map unseen datapoints to meaningful vector representations. For images, pre-trained models such as ResNet or VGG are appropriate. For NLP data, we need document embeddings, so we turn to pre-trained (or fine-tuned) large language models.
While these are just a few examples of large-scale pre-trained models, there exist several others which are trained on different neural-network architectures and different datasets. This approach can be used with any type of vector embedding as long as it is meaningful for the context of your machine learning task.
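As a concrete illustration, here is a minimal sketch of extracting embeddings with off-the-shelf pre-trained models. The specific choices below (a torchvision ResNet-50 for images and the all-MiniLM-L6-v2 sentence encoder for text, along with the preprocessing values) are illustrative, not requirements of the approach.

```python
import torch
from torchvision import models, transforms
from PIL import Image
from sentence_transformers import SentenceTransformer

# Image embeddings: a pre-trained ResNet with its classification head removed,
# so the forward pass returns the penultimate-layer feature vector.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.fc = torch.nn.Identity()  # drop the final classification layer
resnet.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def image_embedding(path: str) -> torch.Tensor:
    """Return a 2048-dimensional embedding for a single image file."""
    img = preprocess(Image.open(path).convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        return resnet(img).squeeze(0)

# Text embeddings: a sentence-level encoder produces a fixed-size document vector.
text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
headline_vector = text_encoder.encode("Police investigate downtown robbery")
```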
Density Model
Once we have meaningful vector representations for every point in our reference dataset, we need a density model that captures the underlying distribution. We can fit a flexible density model to these embedding vectors using many possible techniques: an auto-encoder, a VAE, a normalizing flow, a GAN, and so on. In each case, the density model learns the structure and distribution of the reference set's images or text as represented in the embedding space.
As an example, auto-encoders are frequently used for unsupervised anomaly detection. An auto-encoder learns latent representations of the reference set (here, the vector embeddings) by encoding each vector to a lower-dimensional representation and then decoding it back to its original dimension. The error between the original input vector and the reconstructed output vector is the reconstruction loss. Datapoints similar to the reference distribution will have a lower reconstruction error than points that are very different from it, which makes reconstruction error a natural signal for finding outliers: points outside the distribution of the reference set will reconstruct poorly and score a high error.
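Here is a rough sketch of this idea, assuming 2048-dimensional embeddings like those in the extraction sketch above; the layer sizes are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

class EmbeddingAutoEncoder(nn.Module):
    """Compress embedding vectors to a small latent code and reconstruct them."""
    def __init__(self, dim: int = 2048, latent_dim: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(dim, 512), nn.ReLU(), nn.Linear(512, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 512), nn.ReLU(), nn.Linear(512, dim))

    def forward(self, x):
        return self.decoder(self.encoder(x))

def reconstruction_error(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Per-example mean squared reconstruction error."""
    with torch.no_grad():
        return ((model(x) - x) ** 2).mean(dim=1)

# Train on reference embeddings only, minimizing the reconstruction (MSE) loss.
model = EmbeddingAutoEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
# for batch in reference_embedding_loader:   # batches of reference embeddings
#     optimizer.zero_grad()
#     loss_fn(model(batch), batch).backward()
#     optimizer.step()
```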
Taking a look at our news headline example, we can inspect the space learned by our auto-encoder. We first train the model on news headlines categorized as CRIME, which we treat as our in-distribution data. Below is a visualization of held-out crime headlines, as well as entertainment headlines.
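A plot like this can be produced by projecting the auto-encoder's latent codes down to two dimensions. A quick sketch, assuming the `model` from the auto-encoder sketch above and embedding tensors `crime_embeddings` and `entertainment_embeddings` for the two categories:

```python
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA

with torch.no_grad():
    crime_latent = model.encoder(crime_embeddings).numpy()
    entertainment_latent = model.encoder(entertainment_embeddings).numpy()

# Fit the 2D projection on the in-distribution (crime) latents only,
# then project both categories into the same plane.
pca = PCA(n_components=2).fit(crime_latent)
for latent, label in [(crime_latent, "crime (held-out, in-distribution)"),
                      (entertainment_latent, "entertainment (out-of-distribution)")]:
    xy = pca.transform(latent)
    plt.scatter(xy[:, 0], xy[:, 1], s=8, alpha=0.5, label=label)
plt.legend()
plt.show()
```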
Scoring
Once we have trained our density model on our reference set, we must find a way to convert the reconstruction loss values from the model to actionable anomaly scores. Our approach is outlined below:
The motivation for our approach is twofold:
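To make the scoring step concrete, here is a minimal sketch of one common scheme (not necessarily the exact one used in our system): rank each new point's reconstruction error against the empirical distribution of reconstruction errors on held-out reference data, so that scores near 1 indicate points that reconstruct worse than almost everything in the reference set. It reuses `reconstruction_error` from the auto-encoder sketch above.

```python
import numpy as np

def fit_reference_errors(model, reference_embeddings) -> np.ndarray:
    """Sorted reconstruction errors on held-out reference embeddings."""
    errors = reconstruction_error(model, reference_embeddings).numpy()
    return np.sort(errors)

def anomaly_score(reference_errors: np.ndarray, new_errors: np.ndarray) -> np.ndarray:
    """Fraction of reference errors below each new error (an empirical CDF).
    Values near 0 look typical of the reference set; values near 1 look anomalous."""
    ranks = np.searchsorted(reference_errors, new_errors, side="right")
    return ranks / len(reference_errors)
```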
Evaluation/Results
There are very few open-source datasets with labels suitable for evaluating anomaly detection on unstructured data. We therefore constructed a few test cases from the example datasets introduced earlier in this post to measure the efficacy of our anomaly detection algorithm for unstructured data.
For each dataset (News Headlines and STL-10), we broke up our test cases as follows: in a given experiment, one class (for example, ship images or CRIME headlines) served as the in-distribution (non-anomalous) reference set used to build the detector, a different class served as the out-of-distribution (anomalous) set, and we measured how well the resulting anomaly scores separated the two.
We highlight two graphics below showcasing the results of our experiments.
The figures above show the ROC curves for two specific experiments we ran using the STL-10 dataset. The graph on the left corresponds to an experiment where the in-distribution (non-anomalous) dataset was a set of ship images and the out-of-distribution (anomalous) dataset was a set of plane images. Similarly, the graph on the right shows the ROC curve where the in-distribution dataset was a set of bird images and the out-of-distribution dataset was a set of plane images. In both experiments, the anomaly detector differentiates well between in-distribution and out-of-distribution datapoints, with AUC scores of 0.804 and 0.996.
The heatmap above reports the AUC scores for all pairwise experiments between classes in the news headline dataset (Crime, Entertainment, and so on). For any given cell, the category on the x-axis is the in-distribution (non-anomalous) dataset and the category on the y-axis is the out-of-distribution (anomalous) dataset. The average AUC score across all pairwise experiments was 0.83, which is quite impressive given that this task is difficult even for humans.
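For reference, scoring such an experiment is straightforward once anomaly scores are in hand. A short sketch, assuming arrays `in_dist_scores` and `out_dist_scores` hold the anomaly scores for held-out in-distribution and out-of-distribution points:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

# Label out-of-distribution points as the positive class (1).
labels = np.concatenate([np.zeros(len(in_dist_scores)), np.ones(len(out_dist_scores))])
scores = np.concatenate([in_dist_scores, out_dist_scores])

auc = roc_auc_score(labels, scores)        # area under the ROC curve
fpr, tpr, _ = roc_curve(labels, scores)    # points for plotting the curve
print(f"AUC: {auc:.3f}")
```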
Conclusion
This approach to out-of-distribution detection is especially powerful because it is completely unsupervised. In a production environment, we often don’t have prior knowledge of what kind of distribution shifts to expect or access to labeled data. Additionally, while we have considered two classification problems in this post, this technique can be applied to any type of machine learning task, as it only considers the input data and is therefore independent of the underlying ML task.
Detection of out-of-distribution samples is only the first step in maintaining a robust machine learning system. Monitor your ML model for drift with Arthur. At Arthur, we’re helping data scientists and machine learning engineers detect, understand, and respond to unforeseen production environments.
FAQ
How do changes in the external environment, unrelated to the main features of the model, affect the process of data drift detection in NLP and CV applications within AI and ML frameworks, and how can one adjust the detection mechanisms to accommodate such changes?
Changes in the external environment can significantly impact the effectiveness of AI applications, particularly in NLP and CV, by introducing new patterns or visual trends not present in the training data and leading to higher rates of misclassification or irrelevant results; this is data drift. To adjust the detection mechanisms within AI and ML frameworks, one could incorporate adaptive learning strategies that allow the model to periodically update its parameters based on new data. Implementing a robust anomaly detection framework capable of identifying and adapting to sudden shifts in data distribution without human intervention can also help. Regularly refreshing datasets with recent examples and employing domain adaptation techniques are further effective strategies for mitigating the effects of external changes on ML model performance.
What are the specific computational costs associated with implementing the described data drift detection methods for unstructured data in real-world applications within AI and ML domains, and how do these costs compare with traditional data drift detection methods used for structured data?
Implementing data drift detection methods for unstructured data in real-world AI and ML applications can be significantly more computationally intensive than for structured data. This is primarily due to the complexity of processing and analyzing unstructured data, such as images and text, which requires advanced algorithms and increased computational power typical of AI systems. Techniques like vector representation, density model training, and anomaly scoring, integral to ML workflows, are resource-intensive, especially when handling large datasets. In comparison, traditional data drift detection in structured data, often found in classical ML scenarios, relies on less computationally demanding statistical methods or simpler models. However, the exact computational costs can vary significantly depending on the specifics of each AI and ML implementation, the frequency of model updates, and the volume of data being analyzed.
How does the data drift detection framework integrate with existing machine learning pipelines in AI systems, particularly in automated environments where continuous monitoring and instant decision-making are essential?
The data drift detection framework can be integrated into existing ML pipelines within AI systems as a dedicated monitoring layer that functions in parallel with the main data processing workflow. In automated AI environments, this involves the continuous, real-time analysis of incoming data to assess its conformity to the model's initial training distribution, a cornerstone in ML operations. The framework should trigger alerts or initiate a retraining cycle if significant drift is detected, maintaining the ML model's accuracy and relevance. For effective integration, APIs could be developed to direct data from operational activities straight into the drift detection system and automate responses based on the outcomes, thereby enhancing the AI system's responsiveness and reliability. This ensures that the ML models remain accurate and relevant without disrupting the overall operational flow of AI-driven systems.
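As an illustration of such an integration, here is a hypothetical sketch of a monitoring hook that scores each incoming batch of embeddings and fires an alert callback when too large a share of the batch looks out-of-distribution. The threshold values and the `alert_fn` callback are placeholders, and the helpers reuse the `reconstruction_error` and `anomaly_score` sketches from earlier in the post.

```python
import numpy as np

SCORE_THRESHOLD = 0.95   # anomaly score above which a single point counts as drifted
BATCH_FRACTION = 0.10    # alert if more than 10% of a batch exceeds the score threshold

def monitor_batch(batch_embeddings, model, reference_errors, alert_fn):
    """Score an incoming batch of embeddings and trigger an alert if drift looks likely."""
    errors = reconstruction_error(model, batch_embeddings).numpy()
    scores = anomaly_score(reference_errors, errors)
    drifted = float(np.mean(scores > SCORE_THRESHOLD))
    if drifted > BATCH_FRACTION:
        alert_fn(f"{drifted:.0%} of the batch scored as out-of-distribution")
    return scores
```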