Fast Counterfactual Explanations using Reinforcement Learning
By:
Karthik Rao and Sahil Verma
|
March 18, 2022
Introduction
Counterfactuals, an active area of research in machine learning explainability, are explanations that produce actionable steps to move a data point from one side of a decision boundary to the other. These explanations have clear use cases in applications ranging from loan decisions (example shown below) to healthcare diagnoses.
Problem: A binary classifier, used by a financial services company, predicts whether an individual should be approved (1) or denied (0) for a loan. Individuals whom the model rejects want to know how to get approved.
Counterfactual Explanation: We can now provide a series of steps (as changes in the input features) that can help an individual get approved for a loan, e.g. add $10,000 to your salary and gain 2 more years of education.
In practice, we might have many more requirements of CFEs beyond just finding the other side of the decision boundary. For example, we might need the features to be constrained to only certain actionable sets, or we might need the resultant counterfactual to be realistic and similar to the training data. Recent work has summarized these various desiderata and the common research themes in the field. Additionally, we need to be able to compute counterfactuals in a way that is computationally efficient for high-data-volume use cases. In this article, we present a distributed and scalable reinforcement learning framework that can produce real-time counterfactual explanations for any binary classifier. We provide an overview of the algorithm used to find counterfactual explanations in real time, and implementation details of how we've used OpenAI Gym and Ray (RLlib) to put this work into practice.
Counterfactuals Overview
We define a counterfactual explanation (CFE) as follows:
Definition. Given a model \(f\) and a data point \(x\) with \(f(x) = y\), we want to find a point \(x'\) such that \(f(x') = 1 - y\) while minimizing \(dist(x, x')\) for some distance function \(dist\).
This definition states that we want to find a point \(x'\) where our binary classifier returns the inverse of the original classification \(y\). The definition is also malleable, as we can model different causal relationships between features in order to produce more realistic counterfactual explanations. For example, we can define a set of immutable features which cannot be changed (such as gender and marital status), and a set of non-decreasing features (such as age and education level). We also want to find a counterfactual explanation \(x'\) that is close to the original training data. Such additional constraints allow us to produce counterfactual explanations that are realistic and well-defined.
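Putting the definition and these constraints together, the per-instance search can be written as a constrained optimization (a sketch, where \(\mathcal{I}\) denotes the immutable features and \(\mathcal{D}\) the non-decreasing ones):
\[
x' = \operatorname*{arg\,min}_{z} \; dist(x, z) \quad \text{s.t.} \quad f(z) = 1 - y, \qquad z_i = x_i \;\; \forall i \in \mathcal{I}, \qquad z_j \ge x_j \;\; \forall j \in \mathcal{D}
\]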
One of the key shortcomings of most counterfactual approaches is that they are computed separately for each individual instance in a dataset. This means that for every inference, we must solve a new optimization problem to find a counterfactual explanation with the following properties:
Correct Output: We need the new data point \(x'\) to have the desired output from the model (e.g., moving from a classification of 0 to 1).
Minimize Data Manifold Distance: We want to minimize the distance between our new counterfactual data point \(x'\) and our original training dataset (measured here via KNN distance).
Adhere to Feature Constraints: We have to respect the properties of features, especially if they are immutable (cannot be changed) or non-decreasing (can only increase or stay the same).
Solving a new optimization problem for each data point can be expensive, preventing us from creating real-time counterfactual explanations. Our approach creates a model that allows us to pay for this expensive training upfront for the entire dataset and produce fast inferences for any new data point that needs a counterfactual explanation.
Reinforcement Learning Framework for CFEs
Based on our goal of achieving near real-time explanations, we considered a reinforcement learning framework [Verma, Hines, Dickerson] that allows us to do a one-time training of the model and produce explanations in real-time. In this section, we will present a brief overview of reinforcement learning and how we applied it to counterfactual explanations.
Reinforcement Learning Overview
Reinforcement learning is a machine learning framework that allows an agent to interactively "learn" the best set of actions to reach a desired state in a given environment. Reinforcement learning has had the most success in robotics and game-playing scenarios (RL models have beaten the best players in the world at poker and Go).
A few key terms will be used throughout this article as we model our counterfactual problem:
Agent: The agent operates in the entire state space (environment) for a problem. The agent aims to take the right series of actions through the state space to perform some task. It samples the environment in order to learn the best action in any state.
Environment: A generic set of states that produce a reward when an agent moves from state to state. The agent will only learn about the reward from the environment once it has arrived at a certain state.
Action: At any given state \(s_t\), an agent takes an action to move to state \(s_{t+1}\). The goal of any RL problem is to find a series of actions that moves the agent from an undesirable state to a desirable one (by maximizing expected reward across all the states it visits).
For the purposes of this article, we will not provide a more comprehensive overview of reinforcement learning; we highly recommend reading [this blog], which does. That said, the basic agent-environment interaction loop is only a few lines, as the sketch below shows.
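A minimal sketch of that loop using a stock Gym environment (CartPole-v1 is just an illustrative choice, and this agent simply acts at random):
import gym

# At each state s_t the agent picks an action; the environment returns the
# next state s_(t+1), a reward, and whether the episode is done
env = gym.make("CartPole-v1")
state = env.reset()
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()            # a (random) action at s_t
    state, reward, done, info = env.step(action)  # move to s_(t+1), observe reward
    total_reward += reward
print(f"Episode reward: {total_reward}")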
Reinforcement Learning for CFEs
Our recent work (appearing soon at AAAI) shows how to use a reinforcement learning framework to generate real-time counterfactual explanations. Let us consider a very simple dataset: \(x_1, x_2, y\). We have some binary classifier that creates a decision boundary shown in Figure 1. Our goal is to move any arbitrary point on the left side of the boundary, namely points in blue, across the boundary to be classified as the opposite class (orange points).
Figure 1: Decision Boundary for Arbitrary Dataset. Note the binary classifier is not perfect and mislabels a few points, namely the blue points on the right side of the graph.
To frame this problem as a reinforcement learning problem, we define the following:
Agent: In our counterfactual scenario, the agent can be located at any point on the left side of the decision boundary (any of the blue points). The agent needs to take a series of actions (which we define below) to move from the left side of the decision boundary to the right side.
Environment: Our environment (which represents our state space) is the entire grid, discretized into a finite number of states for every input in our feature space. The environment sends a reward to an agent at state \(s\). The reward function is the sum of two components:
Classification Reward: We want to reward agents that reach the correct side of the decision boundary, so we give a high reward once an agent crosses it. We use the binary classifier's prediction to determine how close an agent is to the boundary (how close the class probability is to the decision threshold): the closer the agent is to the decision boundary, the higher the reward it receives, and vice versa. This value is always between 0 and 1.
Dataset Manifold Reward: We want our final counterfactual point to be similar to the training dataset, so we assign a larger negative reward to points that are dissimilar to the training data. This similarity can be measured in several ways (auto-encoders, K-Nearest-Neighbors, etc.); for this implementation we used KNN distance, normalized between 0 and 1, on the final counterfactual point. A minimal sketch of both reward terms appears after this list.
Action: In our discretized environment, we allow the agent to move a small distance along any one of the feature dimensions. In the example above, an action \(a\) for an agent at state \(s\) is a small \((\pm 0.05)\) change to either feature \(x_1\) or \(x_2\). We opted for these small discretized movements to limit the number of moves an agent can make at any given state, and we expect the agent to incrementally learn the right sequence of moves toward the decision boundary.
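The environment code later in this article calls two helper methods, classifier_reward and distance_to_data_manifold, without showing their bodies. Here is a minimal sketch of how they could be implemented, assuming classifier_predict returns the positive-class probability as an array and that a scikit-learn NearestNeighbors index has been fit on the scaled training data; the threshold and the max_dist normalization constant are illustrative:
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

def classifier_reward(classifier_predict, state: pd.Series, threshold: float = 0.5):
    """Classification reward in [0, 1]: the positive-class probability.
    done becomes True once the agent has crossed the decision boundary."""
    prob_positive = float(classifier_predict(state)[0])  # assumption: returns P(y=1)
    done = prob_positive >= threshold
    return prob_positive, done

def distance_to_data_manifold(knn: NearestNeighbors, state: pd.Series, max_dist: float) -> float:
    """Manifold penalty in [0, 1]: mean distance to the k nearest training
    points, normalized by a precomputed maximum distance max_dist."""
    dists, _ = knn.kneighbors(state.values.reshape(1, -1))
    return float(np.clip(dists.mean() / max_dist, 0.0, 1.0))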
Figure 2 shows a possible path that a point can take to find a CFE.
Figure 2: Sample path that an agent can take in our environment.
OpenAI Gym
OpenAI Gym provides a framework that allows us to create an environment for an agent to interact with. It is the de facto standard for defining RL environments (and comes with a set of pre-defined environments for different tasks), but you can also define your own. To train a model for our RL task, we must fit it into the Gym framework.
1. Observation Space: You must define what every state looks like in your RL environment. In Python, all observation spaces must be defined by one of the following types: Tuple, Discrete, Box, Dict, MultiBinary, MultiDiscrete (view here).
# We create an observation space as a Box (np array) with limits (-1, 1)
self.observation_space = gym.spaces.Box(low=np.ones(shape=num_features) * -1,
                                        high=np.ones(shape=num_features))
2. Action Space: You must also define what an action looks like for any given agent. It is defined using the same types as observation spaces (view here).
# We create an action space as a tuple, where the first value is a feature
# index and the second value is binary (increasing or decreasing)
self.action_space = gym.spaces.Tuple((
gym.spaces.Discrete(num_features),
gym.spaces.Discrete(2)
))
3. Step Function: We define a function that, given an action (from the action space) and the current state, returns the reward produced by taking that action. It also returns the new state the agent moved to as a result of the action (which may or may not be deterministic).
def step(self, action: Tuple[int, int]) -> Tuple[pd.Series, float, bool, dict]:
    """
    Step function for the Gym environment
    :param: action: action in the gym environment (transformed_feature_index, increase/decrease)
    :return: state: an observation from the space defined above
    :return: reward: reward for moving to that state
    :return: done: (bool) is the process complete
    :return: info: dict info about the environment at every step (useful for debugging)
    """
    # Get the feature to act on and whether we should increase or decrease it
    feature_index = action[0]
    decrease = bool(action[1])
    # Set default reward to be negative
    reward = -10.0
    done = False
    info = {}
    # Map the feature index to its column name
    feature_name = self.X_train.columns[feature_index]
    # Check to make sure we are not changing an immutable feature
    if feature_name in self.immutable_features:
        return self.state, reward, done, info
    # Check to make sure we are not decreasing a non-decreasing feature
    if feature_name in self.non_decreasing_features and decrease:
        return self.state, reward, done, info
    # Move the agent (X_train should be normalized between 0 and 1)
    new_state = self.state.copy()
    amount = -0.05 if decrease else 0.05
    new_state[feature_name] += amount
    self.state = new_state
    # Compute classifier reward (high once we cross the decision boundary)
    classifier_reward, done = self.classifier_reward(self.state)
    # Compute KNN reward (penalty for being far from the data manifold)
    manifold_dist_loss = self.distance_to_data_manifold(self.state)
    # Compute total reward
    reward = classifier_reward - manifold_dist_loss
    return self.state, reward, done, info
4. Reset Function: Once the agent has reached a desired state (or we want to start over), we need to be able to reset to some starting state. We define the policy for determining the starting state in the reset function.
def reset(self, initial_state: pd.Series = None) -> pd.Series:
    """
    Reset method for the Gym environment. Called after an episode is complete,
    or called with a specified initial state to start an evaluation
    :param: initial_state: initial starting state (used for evaluation)
    """
    if initial_state is not None:
        # This is done for inference
        self.state = initial_state
    else:
        # Randomly get a data point (as a Series) from our training dataset
        self.state = self.X_train.sample().iloc[0]
    return self.state
We can now formally define our FastCFE (fast counterfactual explanations) class, which brings all of the above components together in one Python class.
from typing import Callable, List, Tuple

import gym
import numpy as np
import pandas as pd


class FastCFE(gym.Env):
    def __init__(self,
                 classifier_predict: Callable[[pd.Series], np.ndarray],
                 X_train: pd.DataFrame,
                 categorical_features: List[str],
                 immutable_features: List[str] = None,
                 non_decreasing_features: List[str] = None):
        """
        Initializes the Gym environment to train the RL model for Counterfactual Explanations (CFE)
        (https://arxiv.org/pdf/2106.03962.pdf)
        """
        # Initialize all global class variables
        self.classifier = classifier_predict
        self.X_train = X_train
        self.categorical_features = categorical_features
        self.immutable_features = immutable_features or []
        self.non_decreasing_features = non_decreasing_features or []
        # Create the state and action space
        # We create an observation space as a Box (np array) with limits (-1, 1)
        self.observation_space = gym.spaces.Box(low=np.ones(shape=len(X_train.columns)) * -1,
                                                high=np.ones(shape=len(X_train.columns)))
        # We create an action space as a tuple, where the first value is a feature
        # index and the second value is binary (increasing or decreasing)
        self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(len(X_train.columns)),
                                              gym.spaces.Discrete(2)))

    def step(self, action: Tuple[int, int]) -> Tuple[pd.Series, float, bool, dict]:
        """
        Step function for the Gym environment
        :param: action: action in the gym environment (transformed_feature_index, increase/decrease)
        :return: state: an observation from the space defined above
        :return: reward: reward for moving to that state
        :return: done: (bool) is the process complete
        :return: info: dict info about the environment at every step (useful for debugging)
        """
        # Get the feature to act on and whether we should increase or decrease it
        feature_index = action[0]
        decrease = bool(action[1])
        # Set default reward to be negative
        reward = -10.0
        done = False
        info = {}
        # Map the feature index to its column name
        feature_name = self.X_train.columns[feature_index]
        # Check to make sure we are not changing an immutable feature
        if feature_name in self.immutable_features:
            return self.state, reward, done, info
        # Check to make sure we are not decreasing a non-decreasing feature
        if feature_name in self.non_decreasing_features and decrease:
            return self.state, reward, done, info
        # Move the agent (X_train should be normalized between 0 and 1)
        new_state = self.state.copy()
        amount = -0.05 if decrease else 0.05
        new_state[feature_name] += amount
        self.state = new_state
        # Compute classifier reward (high once we cross the decision boundary).
        # classifier_reward and distance_to_data_manifold are helper methods
        # (not shown here); a sketch of both was given earlier in this article
        classifier_reward, done = self.classifier_reward(self.state)
        # Compute KNN reward (penalty for being far from the data manifold)
        manifold_dist_loss = self.distance_to_data_manifold(self.state)
        # Compute total reward
        reward = classifier_reward - manifold_dist_loss
        return self.state, reward, done, info

    def reset(self, initial_state: pd.Series = None) -> pd.Series:
        """
        Reset method for the Gym environment. Called after an episode is complete,
        or called with a specified initial state to start an evaluation
        :param: initial_state: initial starting state (used for evaluation)
        """
        if initial_state is not None:
            # This is done for inference
            self.state = initial_state
        else:
            # Randomly get a data point (as a Series) from our training dataset
            self.state = self.X_train.sample().iloc[0]
        return self.state
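Before wiring this into any RL training, it can help to sanity check the environment by hand. Below is a hypothetical smoke test, assuming my_predict_fn and a [0, 1]-scaled X_train exist and that the reward helpers sketched earlier have been attached to the class:
# Build the environment and take one random step (no RL involved yet)
env = FastCFE(classifier_predict=my_predict_fn,  # hypothetical prediction function
              X_train=X_train,                   # assumed scaled to [0, 1]
              categorical_features=[])
state = env.reset()                              # random training point
action = env.action_space.sample()               # random (feature index, direction) pair
next_state, reward, done, info = env.step(action)
print(reward, done)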
Training the Model using Ray + RLlib
Now that we have created our environment, actions, rewards, and state space, and have properly defined our OpenAI Gym environment, we need to train a model that produces real-time counterfactual explanations. We must learn an optimal policy, \({\pi}(s)\), that produces, for an agent in state \(s\), the action that maximizes future expected reward (reaching the desirable state across the decision boundary). In reinforcement learning, there are several ways to find the optimal policy \({\pi}\), ranging from model-free to model-based optimization techniques. We recommend this link for a deeper review of training reinforcement learning algorithms.
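Written out with a discount factor \(\gamma\) and the reward \(r(s_t, a_t)\) defined by our environment, the optimal policy maximizes the expected discounted return (a standard formulation):
\[
\pi^{*} = \operatorname*{arg\,max}_{\pi} \; \mathbb{E}_{\pi}\!\left[\sum_{t=0}^{T} \gamma^{t}\, r(s_t, a_t)\right]
\]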
Ray and RLlib
We were focused on finding an optimization algorithm and framework that is fast, scalable, and easy to use. Much recent work has focused on distributed deep reinforcement learning, which uses neural networks to implicitly learn the optimal policy \({\pi}\). One such framework is [Ray + RLlib]:
Ray is a distributed framework in Python designed to distribute training tasks across any cluster. RLlib is a package within Ray designed to train RL agents in different environments. RLlib offers a variety of optimization algorithms and provides a configuration dictionary that makes it easy to distribute training and evaluation across cores and machines. Furthermore, it provides an API that allows us to modify the internals of the deep RL optimizer. RLlib is maintained by Anyscale (founded out of the Berkeley RISELab) and is one of the state-of-the-art frameworks for distributed computing and machine learning.
For our implementation, we opted to use Proximal Policy Optimization (PPO), a deep RL algorithm that outperforms other online policy-gradient methods and strikes a favorable balance between sample complexity, simplicity, and wall-clock time [Schulman et al.]. We needed an algorithm that trains relatively fast and distributes easily, both of which PPO provides out of the box via RLlib, whether on a cluster or on your laptop. Below, we show pseudo-code for how we wrapped the native RLlib PPO optimizer in a FastCFE-specific class:
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from ray.rllib.agents import ppo
from ray.tune.registry import register_env


class RayLib:
    def __init__(self, env: FastCFE) -> None:
        """
        Initializes a RayLib optimizer for PPO (can add functionality for other optimizers)
        :param: env: the FastCFE class which contains our gym environment
        """
        self.env = env
        self.model: Optional[ppo.PPOTrainer] = None
        self.config: Optional[Dict[str, Any]] = None

    def train(self, num_epochs: int = 10) -> None:
        """
        Train the RL agent for the CFE environment (with hyperparameters)
        View more of the features here: (https://docs.ray.io/en/master/rllib-training.html)
        """
        self.config = ppo.DEFAULT_CONFIG.copy()

        # Register the CFE environment to be trained in Ray
        def env_creator(env_config):
            return self.env

        register_env("my_env", env_creator)
        # Create the PPO Trainer with the configuration and train
        self.model = ppo.PPOTrainer(env="my_env", config=self.config)
        print("Training RL Agent.....")
        for i in range(num_epochs):
            result = self.model.train()
            print(f"Training Iteration: {i+1}/{num_epochs}")
            print(f"Total number of datapoints sampled: {result['episodes_total']}")
            print(f"Mean Reward (Max of 100): {result['episode_reward_mean']}")
            print(f"Avg. steps to produce counterfactual (Min: 1, Max: 20): {result['episode_len_mean']}")
            print("_______________________________________________________________________________________")

    def predict(self, initial_state: pd.Series, max_steps: int = 50) -> Tuple[List[pd.Series], float, bool]:
        """
        Make a counterfactual prediction
        :param: initial_state: the data point we want a counterfactual explanation for
        :param: max_steps: maximum number of actions to take before giving up
        """
        if self.model is None:
            raise ValueError("Model not loaded")
        next_state = self.env.reset(initial_state=initial_state)
        done: bool = False
        reward: float = 0.0
        steps: List[pd.Series] = []
        for _ in range(max_steps):
            action = self.model.compute_action(next_state)
            next_state, reward, done, info = self.env.step(action)
            steps.append(next_state)
            if done:
                break
        return steps, reward, done
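Putting it all together, here is a hypothetical end-to-end usage of the two classes above. The dataset loading, scaling, and the fitted binary classifier (model) are assumed to exist elsewhere, the feature names are illustrative, and FastCFE is assumed to include the reward helpers sketched earlier:
# X_train is assumed to be scaled to [0, 1]; model is any fitted binary
# classifier exposing predict_proba (e.g. from scikit-learn)
env = FastCFE(classifier_predict=lambda s: model.predict_proba(s.values.reshape(1, -1))[:, 1],
              X_train=X_train,
              categorical_features=["purpose"],
              immutable_features=["gender", "marital_status"],
              non_decreasing_features=["age", "education_level"])

optimizer = RayLib(env=env)
optimizer.train()

# Produce a counterfactual path for a rejected applicant
rejected_applicant = X_train.iloc[0]
steps, reward, done = optimizer.predict(initial_state=rejected_applicant)
print(f"Counterfactual found: {done}, in {len(steps)} steps")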
Benchmarks
We want to showcase two major benchmarks for this algorithm and implementation:
Performance Metrics: We want to understand how our FastCFE approach using RL compares against other methods on the performance metrics described later.
Training Time: We want to see how much training time we saved by using a distributed RL framework like Ray compared to a single-threaded research implementation.
These benchmarks were run on the following credit risk datasets. All of these datasets contain some form of credit risk data, and the models are binary classifiers that predict whether a single applicant (one row) should be accepted or rejected for a loan. We want to find counterfactual explanations for all of the rejected applicants. The sizes of the datasets are shown below (number of rows by number of columns):
German Credit: 1,000 data-points x 21 features
Adult Credit: 48,842 data-points x 14 features
Credit Default: 30,000 data-points x 25 features
Small Encoded Credit Dataset: 332,000 data-points x 44 features
Large Encoded Credit Dataset: 332,000 data-points x 652 features
The first three datasets (German Credit, Adult Credit, Credit Default) are open-source datasets, with links provided above. The last two are proprietary datasets with obfuscated column names; these larger datasets tested the scalability of our implementation.
Performance Metrics
We want to see how our FastCFE model compares to other well-known methods. Specifically, we are focusing on the following two metrics:
Validity: This is the total number of counterfactual explanations found divided by the total number of data points. This is represented as a percentage.
Mean Inference Time: This is the mean time it takes to compute a counterfactual, i.e., the time it takes to compute \(n\) inferences divided by \(n\). Both metrics are written out as formulas below.
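In formula form, with \(t_i\) the time to compute the \(i\)-th counterfactual:
\[
\text{Validity} = \frac{\#\{\text{CFEs found}\}}{\#\{\text{data points}\}} \times 100\%, \qquad \text{Mean Inference Time} = \frac{1}{n}\sum_{i=1}^{n} t_i
\]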
The results below compare FastCFE against a number of state-of-the-art counterfactual explanation methods:
Adult Credit

Method        | Validity (%) | Mean Inference Time (s)
Dice-Genetic  | 98.1         | 1.71
Dice-Random   | 100          | 0.17
MACE LR       | 100          | 38.45
MACE RF       | 100          | 101.29
FastCFE       | 97.3         | 0.07
German Credit

Method        | Validity (%) | Mean Inference Time (s)
Dice-Genetic  | 89.5         | 3.45
Dice-Random   | 100          | 0.22
Dice-Gradient | 84           | 59.75
FastCFE       | 100          | 0.015
Credit Default

Method        | Validity (%) | Mean Inference Time (s)
Dice-Genetic  | 92.6         | 3.58
Dice-Random   | 100          | 0.39
Dice-Gradient | 81.0         | 479.17
FastCFE       | 99.9         | 0.051
As the tables show, FastCFE's validity is nearly as high as that of the best method (Dice-Random) across these three datasets, while its inference times are up to roughly 15x faster than Dice-Random's (and orders of magnitude faster than MACE and Dice-Gradient).
Training Time
The first implementation of this project used a package called Stable-Baselines3 and was trained naively on a single machine. This section compares the training time of our scalable, distributed RLlib implementation against that naive implementation. The results are shown below:
Dataset        | RLlib Train Time (hrs) | Naive Train Time (hrs)
German Credit  | 0.25                   | 1.5
Credit Default | 1                      | 6
Small Encoded  | 1.5                    | 8
Large Encoded  | 8                      | DNF (did not finish)
We achieve a nearly 6x savings in training time and can handle much larger datasets than we could with our naive implementation. This shows the promise and power of a scalable, distributed reinforcement learning framework: we can significantly reduce training times, which are a major bottleneck for many reinforcement learning applications.
Conclusion
We hope this article provided an overview of the following ideas and concepts:
Counterfactual Explanations: What they are and how they are useful for industrial and explainability applications
Reinforcement Learning Implementation: How we implemented a production-level reinforcement learning model.
Power of Distribution: How we can achieve significant savings by using scalable, distributed reinforcement learning frameworks such as RLlib.
We hope we have provided some interesting ideas and starter code to help you build your own reinforcement learning model. If you would like to learn more about this work, please reach out!