Generate feasible counterfactual explanations using a VAE
This notebook presents a variational inference based approach for generating feasible counterfactuals, in which we first train an encoder-decoder framework and then use it to generate counterfactuals. More details about our framework can be found here: https://arxiv.org/abs/1912.03277
[1]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions
%load_ext autoreload
%autoreload 2
DiCE requires two inputs: a training dataset and a pre-trained ML model. It can also work without access to the full dataset (see this notebook for advanced examples).
Loading dataset
We use the “adult” income dataset from UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in dice_ml.utils.helpers module.
[2]:
dataset = helpers.load_adult_income_dataset()
---------------------------------------------------------------------------
RemoteDisconnected Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 dataset = helpers.load_adult_income_dataset()
RemoteDisconnected: Remote end closed connection without response
This dataset has 8 features. The outcome, income, is binarized to 0 (low-income, <=50K) or 1 (high-income, >50K).
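As a small illustration of the binarization described above, the sketch below maps raw income labels to 0 or 1. It assumes the raw labels are the strings `<=50K` and `>50K` (optionally with surrounding whitespace or a trailing period); `binarize_income` is a hypothetical helper written for this example, not a function from dice_ml.

```python
def binarize_income(label: str) -> int:
    """Map a raw income label to 0 (<=50K) or 1 (>50K)."""
    # Assumption: labels may carry stray whitespace or a trailing period.
    cleaned = label.strip().rstrip(".")
    if cleaned == ">50K":
        return 1
    if cleaned == "<=50K":
        return 0
    raise ValueError(f"unexpected income label: {label!r}")


raw_labels = ["<=50K", ">50K", " >50K.", "<=50K"]
binary = [binarize_income(label) for label in raw_labels]
print(binary)
```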
[3]:
dataset.head()
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [3], in <cell line: 1>()
----> 1 dataset.head()
NameError: name 'dataset' is not defined
[4]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info
[4]:
{'age': 'age',
'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
'race': 'white or other race?',
'gender': 'male or female?',
'hours_per_week': 'total work hours per week',
'income': '0 (<=50K) vs 1 (>50K)'}
Given this dataset, we construct a data object for DiCE. Because continuous and discrete features are perturbed differently, we need to specify the names of the continuous features. DiCE also requires the name of the output variable that the ML model will predict.
[5]:
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'],
                 outcome_name='income', data_name='adult', test_size=0.1)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [5], in <cell line: 1>()
----> 1 d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'],
2 outcome_name='income', data_name='adult', test_size=0.1)
NameError: name 'dataset' is not defined
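To illustrate why continuous and categorical features are declared separately, the sketch below encodes one record by min-max scaling the continuous features and one-hot encoding the categorical ones. This is an assumption about the preprocessing made for illustration, not DiCE's actual code. The category lists are copied from the feature description printed above; the `Other` and `Male` labels, the value ranges, and the `encode` helper are hypothetical. Under these assumptions the encoded vector has 29 entries, which agrees with the `in_features=29` of the first Linear layer of the model printed later in this notebook.

```python
# Category lists taken from the feature description printed above.
# 'Other' (race) and 'Male' (gender) are assumed label spellings.
CATEGORIES = {
    "workclass": ["Government", "Other/Unknown", "Private", "Self-Employed"],
    "education": ["Assoc", "Bachelors", "Doctorate", "HS-grad", "Masters",
                  "Prof-school", "School", "Some-college"],
    "marital_status": ["Divorced", "Married", "Separated", "Single", "Widowed"],
    "occupation": ["Blue-Collar", "Other/Unknown", "Professional", "Sales",
                   "Service", "White-Collar"],
    "race": ["Other", "White"],
    "gender": ["Female", "Male"],
}

# Hypothetical value ranges, used only for min-max scaling in this sketch.
CONTINUOUS_RANGES = {"age": (17, 90), "hours_per_week": (1, 99)}


def encode(instance):
    """Return a flat numeric vector: scaled continuous values, then one-hot blocks."""
    vector = []
    for name, (low, high) in CONTINUOUS_RANGES.items():
        vector.append((instance[name] - low) / (high - low))
    for name, levels in CATEGORIES.items():
        vector.extend(1.0 if instance[name] == level else 0.0 for level in levels)
    return vector


example = {"age": 41, "workclass": "Private", "education": "HS-grad",
           "marital_status": "Single", "occupation": "Service",
           "race": "White", "gender": "Female", "hours_per_week": 45}
encoded = encode(example)
print(len(encoded))  # 2 continuous + 4 + 8 + 5 + 6 + 2 + 2 one-hot columns = 29
```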
Loading the ML model
Below, we use a pre-trained ML model whose accuracy is comparable to other baselines. For convenience, we include the sample trained model with the DiCE package.
Note that we need to specify the explainer in the model backend. This is because both the model and the explainer need to use the same backend library (PyTorch or TensorFlow).
[6]:
backend = {'model': 'pytorch_model.PyTorchModel',
           'explainer': 'feasible_base_vae.FeasibleBaseVAE'}
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
/mnt/c/Users/amshar/code/dice/dice_ml/model.py:34: UserWarning: {'model': 'pytorch_model.PyTorchModel', 'explainer': 'feasible_base_vae.FeasibleBaseVAE'} backend not in supported backends sklearn,TF1,TF2,PYT
warnings.warn('{0} backend not in supported backends {1}'.format(
ML Model Sequential(
(0): Linear(in_features=29, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=2, bias=True)
(3): Softmax(dim=None)
)
Generate counterfactuals using a VAE model
Based on the data object d and the model object m, we can now instantiate the DiCE class for generating explanations. We present the variational inference based approach towards generating counterfactuals, where we first train an encoder-decoder framework to generate counterfactuals.
The FeasibleBaseVAE class has a train() method, which trains the variational encoder-decoder framework on the input dataframe. The method takes an argument, pre_trained: if set to 0, the framework is re-trained each time before generating CFs; if set to 1, repeated training is avoided and the latest fitted VAE model is loaded instead.
[7]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2,
                   batch_size=2048, validity_reg=42.0, margin=0.165, epochs=25,
                   wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(pre_trained=1)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [7], in <cell line: 2>()
1 # initiate DiCE
----> 2 exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2,
3 batch_size=2048, validity_reg=42.0, margin=0.165, epochs=25,
4 wm1=1e-2, wm2=1e-2, wm3=1e-2)
5 exp.train(pre_trained=1)
NameError: name 'd' is not defined
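The pre_trained switch described above follows a common train-or-load pattern. The sketch below shows that pattern with a toy class; `ToyExplainer`, its JSON file format, and its fallback to training when no saved file exists are all assumptions made for illustration and do not reflect how FeasibleBaseVAE stores or loads its model.

```python
import json
import os
import tempfile


class ToyExplainer:
    """Hypothetical stand-in for an explainer with a train(pre_trained=...) switch."""

    def __init__(self, save_path):
        self.save_path = save_path
        self.weights = None
        self.times_trained = 0

    def _fit(self):
        # Placeholder for the expensive encoder-decoder training loop.
        self.times_trained += 1
        return {"encoder": [0.1, 0.2], "decoder": [0.3, 0.4]}

    def train(self, pre_trained=0):
        # pre_trained=1: reuse the last saved model if one exists.
        if pre_trained and os.path.exists(self.save_path):
            with open(self.save_path) as handle:
                self.weights = json.load(handle)
            return
        # pre_trained=0 (or nothing saved yet): fit again and save the result.
        self.weights = self._fit()
        with open(self.save_path, "w") as handle:
            json.dump(self.weights, handle)


save_path = os.path.join(tempfile.mkdtemp(), "toy_vae.json")
explainer = ToyExplainer(save_path)
explainer.train(pre_trained=0)  # fits and saves
explainer.train(pre_trained=1)  # loads the saved weights, no second fit
print(explainer.times_trained)
```

Falling back to training when no saved file is found is a design choice of this sketch; it keeps a first call with pre_trained=1 from failing on a fresh machine.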
DiCE is a form of local explanation and requires a query input whose outcome needs to be explained. Below we provide a sample input whose outcome is 0 (low-income) as per the ML model object m.
[8]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age': 41,
                  'workclass': 'Private',
                  'education': 'HS-grad',
                  'marital_status': 'Single',
                  'occupation': 'Service',
                  'race': 'White',
                  'gender': 'Female',
                  'hours_per_week': 45}
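Before passing a query to the explainer, it can help to check that every feature is present and that each categorical value is one of the levels listed in the feature description printed earlier. The sketch below is a hypothetical pre-check, not part of dice_ml; the `Other` and `Male` labels and the `validate_query` helper are assumptions.

```python
# Allowed levels copied from the feature description shown earlier in the notebook.
# 'Other' (race) and 'Male' (gender) are assumed label spellings.
ALLOWED = {
    "workclass": ["Government", "Other/Unknown", "Private", "Self-Employed"],
    "education": ["Assoc", "Bachelors", "Doctorate", "HS-grad", "Masters",
                  "Prof-school", "School", "Some-college"],
    "marital_status": ["Divorced", "Married", "Separated", "Single", "Widowed"],
    "occupation": ["Blue-Collar", "Other/Unknown", "Professional", "Sales",
                   "Service", "White-Collar"],
    "race": ["Other", "White"],
    "gender": ["Female", "Male"],
}
CONTINUOUS = ("age", "hours_per_week")


def validate_query(query):
    """Return a list of human-readable problems; an empty list means the query looks fine."""
    problems = []
    expected = set(ALLOWED) | set(CONTINUOUS)
    for name in sorted(expected - set(query)):
        problems.append(f"missing feature: {name}")
    for name in sorted(set(query) - expected):
        problems.append(f"unknown feature: {name}")
    for name, levels in ALLOWED.items():
        if name in query and query[name] not in levels:
            problems.append(f"invalid value for {name}: {query[name]!r}")
    for name in CONTINUOUS:
        if name in query and not isinstance(query[name], (int, float)):
            problems.append(f"non-numeric value for {name}: {query[name]!r}")
    return problems


good_query = {"age": 41, "workclass": "Private", "education": "HS-grad",
              "marital_status": "Single", "occupation": "Service",
              "race": "White", "gender": "Female", "hours_per_week": 45}
bad_query = dict(good_query, education="PhD")
del bad_query["gender"]

print(validate_query(good_query))
print(validate_query(bad_query))
```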
Given the query input, we can now generate counterfactual explanations: perturbed versions of the original input for which the ML model outputs class 1 (high-income).
[9]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [9], in <cell line: 2>()
1 # generate counterfactuals
----> 2 dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
3 # visualize the results
4 dice_exp.visualize_as_dataframe(show_only_changes=True)
NameError: name 'exp' is not defined
That’s it! You can try generating counterfactual explanations for other examples using the same code. You can also compare the running time of this VAE-based method with that of DiCE’s default method: the VAE-based method is very fast!
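One way to make the running-time comparison concrete is a small timing helper built on `time.perf_counter`. The sketch below is only an illustration: `time_call`, `slow_method`, and `fast_method` are hypothetical names, and the two workloads are stand-ins, not DiCE methods.

```python
import time


def time_call(fn, *args, repeats=3, **kwargs):
    """Return the best wall-clock time in seconds over `repeats` calls of fn."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in workloads so the sketch runs on its own.
def slow_method():
    return sum(i * i for i in range(200_000))


def fast_method():
    return sum(i * i for i in range(2_000))


slow_seconds = time_call(slow_method)
fast_seconds = time_call(fast_method)
print(f"slow: {slow_seconds:.6f}s, fast: {fast_seconds:.6f}s")
```

In the notebook, the stand-ins could be replaced by a call such as `lambda: exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")` for each explainer being compared.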
Adding feasibility constraints
However, you might notice that for some examples, the above method can still return infeasible counterfactuals. This requires our base framework to be adapted for producing feasible counterfactuals. A detailed description of how we adapt the method under different assumptions is provided in this paper.
In the section below, we show an adaptation of our base approach that preserves the Age-Ed constraint: Age and Education can never decrease, and an increase in Education implies an increase in Age. This approach is called ModelApprox, where we adapt our base approach for simple unary and binary constraints.
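The Age-Ed constraint stated above can be written as a simple check on a pair of records. The sketch below is a hypothetical helper, not part of dice_ml, and the ordering of education levels in `EDUCATION_ORDER` is an assumption made for this example.

```python
# Assumed ordering of the education levels, from lowest to highest.
EDUCATION_ORDER = ["School", "HS-grad", "Some-college", "Assoc",
                   "Bachelors", "Masters", "Prof-school", "Doctorate"]
EDUCATION_RANK = {level: rank for rank, level in enumerate(EDUCATION_ORDER)}


def satisfies_age_ed(original, counterfactual):
    """Check the Age-Ed constraint between an input and a candidate counterfactual."""
    age_delta = counterfactual["age"] - original["age"]
    ed_delta = (EDUCATION_RANK[counterfactual["education"]]
                - EDUCATION_RANK[original["education"]])
    # Age and Education can never decrease.
    if age_delta < 0 or ed_delta < 0:
        return False
    # An increase in Education implies an increase in Age.
    if ed_delta > 0 and age_delta <= 0:
        return False
    return True


original = {"age": 41, "education": "HS-grad"}
print(satisfies_age_ed(original, {"age": 41, "education": "Masters"}))
print(satisfies_age_ed(original, {"age": 45, "education": "Masters"}))
```

The first candidate raises Education without raising Age and is rejected; the second raises both and is accepted.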
ModelApprox
Similar to the FeasibleBaseVAE class above, the FeasibleModelApprox class has a train() method with the argument pre_trained, which determines whether to train the framework again or load the latest optimal model. However, the train() method takes additional arguments:
The first argument determines whether the constraint to be preserved is unary or monotonic.
The second argument provides the list of constraint variable names: [[Effect, Cause_1,..,Cause_n]]. In the case of a unary constraint, there are no causes, only a single constrained variable.
The third argument provides the intended direction of change for the constrained variables: a value of 1 means that only an increase in the constrained variable is allowed when moving from a data point to its counterfactual, and vice versa.
The fourth argument is the penalty weight for infeasibility under the given constraint.
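To give a feel for how a direction and a penalty weight can interact for a unary constraint, the sketch below uses a hinge-style penalty that is zero when the constrained feature moves in the allowed direction and grows linearly otherwise. This is an illustrative assumption, not the loss implemented in FeasibleModelApprox; `unary_constraint_penalty` is a hypothetical function, encoding "decrease only" as -1 is an assumption, and the weight 87 is borrowed from the train() call in this notebook.

```python
def unary_constraint_penalty(x, x_cf, direction, weight):
    """Hinge-style penalty for a single constrained feature.

    x, x_cf   : feature value before and after the counterfactual change
    direction : 1 allows only increases; -1 (assumed) allows only decreases
    weight    : penalty weight applied to the size of the violation
    """
    change = (x_cf - x) * direction
    # A change in the allowed direction gives change >= 0 and no penalty.
    return weight * max(0.0, -change)


# Normalized age going up is allowed when direction is 1.
print(unary_constraint_penalty(0.4, 0.5, direction=1, weight=87.0))
# Normalized age going down violates the constraint and is penalized.
print(unary_constraint_penalty(0.5, 0.4, direction=1, weight=87.0))
```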
Initialize the Model and Explainer for FeasibleModelApprox
[10]:
backend = {'model': 'pytorch_model.PyTorchModel',
           'explainer': 'feasible_model_approx.FeasibleModelApprox'}
# pass the backend name string, as in the FeasibleBaseVAE cell above
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
ML Model Sequential(
(0): Linear(in_features=29, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=2, bias=True)
(3): Softmax(dim=None)
)
/mnt/c/Users/amshar/code/dice/dice_ml/model.py:34: UserWarning: {'model': 'pytorch_model.PyTorchModel', 'explainer': 'feasible_model_approx.FeasibleModelApprox'} backend not in supported backends sklearn,TF1,TF2,PYT
warnings.warn('{0} backend not in supported backends {1}'.format(
[11]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2, batch_size=2048,
                   validity_reg=76.0, margin=0.344, epochs=25,
                   wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(1, [[0]], 1, 87, pre_trained=1)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [11], in <cell line: 2>()
1 # initiate DiCE
----> 2 exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2, batch_size=2048,
3 validity_reg=76.0, margin=0.344, epochs=25,
4 wm1=1e-2, wm2=1e-2, wm3=1e-2)
5 exp.train(1, [[0]], 1, 87, pre_trained=1)
NameError: name 'd' is not defined
[12]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Input In [12], in <cell line: 2>()
1 # generate counterfactuals
----> 2 dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
3 # visualize the results
4 dice_exp.visualize_as_dataframe(show_only_changes=True)
NameError: name 'exp' is not defined
The results for ModelApprox show that, unlike with the BaseVAE method, Age also increases whenever Education increases in the counterfactual explanations. You can experiment with ModelApprox to preserve unary and monotonic constraints on other datasets too. Examples for more advanced approaches such as SCMGenCF and OracleGenCF, where we learn to generate feasible counterfactuals under complex feasibility constraints, will be added to this repository soon. More details can be found in our paper.