Generate feasible counterfactual explanations using a VAE

This notebook presents a variational-inference-based approach for generating feasible counterfactuals: we first train an encoder-decoder framework (a VAE) and then use it to generate counterfactuals. More details about the framework can be found here: https://arxiv.org/abs/1912.03277

[1]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers  # helper functions

%load_ext autoreload
%autoreload 2

DiCE requires two inputs: a training dataset and a pre-trained ML model. It can also work without access to the full dataset (see this notebook for advanced examples).

Loading dataset

We use the “adult” income dataset from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in the dice_ml.utils.helpers module.

[2]:
dataset = helpers.load_adult_income_dataset()

This dataset has 8 features. The outcome is income, which is binarized to 0 (low-income, <=50K) or 1 (high-income, >50K).
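As a small illustration of the binarization described above, the sketch below maps raw income labels to 0 or 1. The helper name `binarize_income` and the exact raw label strings are assumptions made for illustration; the actual transformation is done inside the dice_ml.utils.helpers module.

```python
def binarize_income(label: str) -> int:
    """Map a raw income label to 0 (<=50K) or 1 (>50K)."""
    # Tolerate surrounding whitespace and a trailing period in the raw label.
    cleaned = label.strip().rstrip(".")
    if cleaned == "<=50K":
        return 0
    if cleaned == ">50K":
        return 1
    raise ValueError(f"unexpected income label: {label!r}")


# Hypothetical raw labels, not rows taken from the real dataset.
raw_labels = ["<=50K", ">50K", " <=50K", ">50K."]
binary_labels = [binarize_income(value) for value in raw_labels]
print(binary_labels)
```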

[3]:
dataset.head()
[4]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info
[4]:
{'age': 'age',
 'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
 'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
 'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'race': 'white or other race?',
 'gender': 'male or female?',
 'hours_per_week': 'total work hours per week',
 'income': '0 (<=50K) vs 1 (>50K)'}

Given this dataset, we construct a data object for DiCE. Since continuous and discrete features are perturbed differently, we need to specify the names of the continuous features. DiCE also requires the name of the output variable that the ML model will predict.

[5]:
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'],
                 outcome_name='income', data_name='adult', test_size=0.1)
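One way to see why the two kinds of features need different treatment: a continuous feature can be scaled and nudged by small amounts, while a categorical feature can only switch between levels. The sketch below assumes min-max scaling for continuous features and one-hot encoding for categorical ones. The helper names and the age range are hypothetical, and this is not DiCE's internal code.

```python
# Levels of 'workclass' as listed by helpers.get_adult_data_info().
CATEGORIES = {
    "workclass": ["Government", "Other/Unknown", "Private", "Self-Employed"],
}
# Hypothetical (min, max) range for 'age'; the real range comes from the data.
RANGES = {"age": (17, 90)}


def min_max_scale(value, low, high):
    """Scale a continuous value to the [0, 1] interval."""
    return (value - low) / (high - low)


def one_hot(value, levels):
    """Encode a categorical value as a 0/1 indicator list over its levels."""
    if value not in levels:
        raise ValueError(f"unknown level: {value!r}")
    return [1 if value == level else 0 for level in levels]


row = {"age": 41, "workclass": "Private"}
encoded = [min_max_scale(row["age"], *RANGES["age"])]
encoded += one_hot(row["workclass"], CATEGORIES["workclass"])
print(encoded)
```

In this representation, perturbing age means moving a single number within [0, 1], whereas perturbing workclass means moving the single 1 to a different position.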

Loading the ML model

Below, we use a pre-trained ML model that achieves accuracy comparable to other baselines. For convenience, we include the sample trained model with the DiCE package.

Note that we need to specify the explainer in the model backend. This is because both the model and the explainer need to use the same backend library (PyTorch or TensorFlow).

[6]:
backend = {'model': 'pytorch_model.PyTorchModel',
           'explainer': 'feasible_base_vae.FeasibleBaseVAE'}
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
ML Model Sequential(
  (0): Linear(in_features=29, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=2, bias=True)
  (3): Softmax(dim=None)
)
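The `in_features=29` of the first layer can be cross-checked against the feature description printed earlier. The sketch below assumes the two continuous features occupy one input each and every categorical feature is one-hot encoded. The level counts are read off the output of helpers.get_adult_data_info(), with race and gender taken to have two levels each.

```python
# Number of levels per categorical feature, counted from the descriptions above.
levels_per_categorical = {
    "workclass": 4,
    "education": 8,
    "marital_status": 5,
    "occupation": 6,
    "race": 2,
    "gender": 2,
}
continuous_features = ["age", "hours_per_week"]

# One input per continuous feature plus one input per categorical level.
encoded_width = len(continuous_features) + sum(levels_per_categorical.values())
print(encoded_width)  # 29
```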

Generate counterfactuals using a VAE model

Based on the data object d and the model object m, we can now instantiate the DiCE class for generating explanations. We use the variational-inference-based approach, in which an encoder-decoder framework is trained first and then used to generate counterfactuals.

The FeasibleBaseVAE class has a method, train(), which trains the variational encoder-decoder framework on the input dataframe. It takes an argument, pre_trained: if set to 0, the framework is retrained each time before generating CFs; if set to 1, repeated training is avoided and the latest fitted VAE model is loaded instead.

[7]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2,
                   batch_size=2048, validity_reg=42.0, margin=0.165, epochs=25,
                   wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(pre_trained=1)
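To give some intuition for the `validity_reg` and `margin` arguments above, here is a minimal sketch of a VAE-style objective with a hinge-based validity term. It assumes the validity term penalizes counterfactuals whose target-class score does not exceed the other class score by at least `margin`, weighted by `validity_reg`. The function names and the example scores are hypothetical, and the code is not DiCE's implementation.

```python
def hinge_validity_loss(target_score, other_score, margin):
    """Zero once the target class leads the other class by at least `margin`."""
    return max(0.0, margin - (target_score - other_score))


def total_loss(reconstruction, kl, target_score, other_score, validity_reg, margin):
    """Reconstruction + KL terms plus the weighted validity penalty (assumed form)."""
    validity = hinge_validity_loss(target_score, other_score, margin)
    return reconstruction + kl + validity_reg * validity


# A counterfactual still classified as the original class is penalized.
invalid_penalty = hinge_validity_loss(target_score=0.40, other_score=0.60, margin=0.165)
# A counterfactual that clears the margin incurs no validity penalty.
valid_penalty = hinge_validity_loss(target_score=0.70, other_score=0.30, margin=0.165)
print(invalid_penalty, valid_penalty)
```

Under this reading, a larger `margin` asks for more confident counterfactuals, and a larger `validity_reg` trades proximity to the query for validity.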

DiCE is a form of local explanation and requires a query input whose outcome needs to be explained. Below we provide a sample input whose outcome is 0 (low-income) as per the ML model object m.

[8]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age': 41,
                  'workclass': 'Private',
                  'education': 'HS-grad',
                  'marital_status': 'Single',
                  'occupation': 'Service',
                  'race': 'White',
                  'gender': 'Female',
                  'hours_per_week': 45}

Given the query input, we can now generate counterfactual explanations that show how the original input can be perturbed so that the ML model outputs class 1 (high-income).

[9]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)
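The `show_only_changes=True` option is easiest to understand with a toy version of the idea: features that keep the query's value are replaced by a placeholder so that only the changed ones stand out. The sketch below is an illustration under that assumption. The function name and the counterfactual are made up by hand and are not output from DiCE.

```python
def show_only_changes(query, counterfactual, placeholder="-"):
    """Return the counterfactual with unchanged feature values replaced by a placeholder."""
    missing = object()  # sentinel for features absent from the query (e.g. the outcome)
    return {
        name: (value if query.get(name, missing) != value else placeholder)
        for name, value in counterfactual.items()
    }


query = {"age": 41, "education": "HS-grad", "hours_per_week": 45}
# Hand-made counterfactual for illustration only.
hand_made_cf = {"age": 41, "education": "Masters", "hours_per_week": 45, "income": 1}
changes = show_only_changes(query, hand_made_cf)
print(changes)
```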

That’s it! You can try generating counterfactual explanations for other examples using the same code. You can also compare the running time of this VAE-based method with DiCE’s default method: the VAE-based method is much faster.

Adding feasibility constraints

However, you might notice that for some examples, the above method can still return infeasible counterfactuals. Avoiding them requires adapting our base framework to produce feasible counterfactuals. A detailed description of how we adapt the method under different assumptions is provided in this paper.

In the section below, we show an adaptation of our base approach that preserves the Age-Ed constraint: Age and Education can never decrease, and an increase in Education implies an increase in Age. This approach is called ModelApprox, where we adapt our base approach for simple unary and binary constraints.
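The Age-Ed constraint can be written down as a simple check on a query and a candidate counterfactual. The sketch below is a minimal illustration: the function name and, in particular, the ordering of education levels in `EDUCATION_RANK` are assumptions made for this example and are not taken from DiCE.

```python
# Hypothetical ordering of the education levels, lowest to highest.
EDUCATION_RANK = {
    "School": 0, "HS-grad": 1, "Some-college": 2, "Assoc": 3,
    "Bachelors": 4, "Masters": 5, "Prof-school": 6, "Doctorate": 7,
}


def satisfies_age_ed(query, counterfactual):
    """Check that Age and Education never decrease and that more Education implies more Age."""
    age_change = counterfactual["age"] - query["age"]
    ed_change = (EDUCATION_RANK[counterfactual["education"]]
                 - EDUCATION_RANK[query["education"]])
    if age_change < 0 or ed_change < 0:
        return False  # neither feature may decrease
    if ed_change > 0 and age_change <= 0:
        return False  # gaining education takes time, so age must increase too
    return True


query = {"age": 41, "education": "HS-grad"}
print(satisfies_age_ed(query, {"age": 45, "education": "Masters"}))
print(satisfies_age_ed(query, {"age": 41, "education": "Masters"}))
```

The first candidate raises both Age and Education and passes the check; the second raises Education while Age stays fixed and fails it.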

ModelApprox

Similar to the FeasibleBaseVAE class above, the FeasibleModelApprox class has a method train() with the argument pre_trained, which determines whether to train the framework again or load the latest optimal model. However, the train() method takes additional arguments:

  1. The first argument determines whether the constraint to be preserved is unary or monotonic.

  2. The second argument provides the list of constraint variable names: [[Effect, Cause_1,..,Cause_n]]. In the case of a unary constraint, there are no causes, only a single constrained variable.

  3. The third argument provides the intended direction of change for the constrained variables: a value of 1 means that only an increase in the constrained variable is allowed when moving from a data point to its counterfactual, and vice versa.

  4. The fourth argument is the penalty weight for infeasibility under the given constraint.
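The roles of the direction and penalty-weight arguments can be illustrated with a toy penalty for a unary constraint. The sketch below assumes a hinge-style penalty that grows linearly with the size of the violation and uses -1 to mean "only decreases allowed"; the function name and both assumptions are hypothetical and do not describe the exact loss used by FeasibleModelApprox.

```python
def unary_constraint_penalty(original, counterfactual, direction, weight):
    """Penalize changes that go against the allowed direction.

    direction = 1 allows only increases; direction = -1 (assumed here)
    allows only decreases. `weight` scales the size of the violation.
    """
    change = counterfactual - original
    return weight * max(0.0, -direction * change)


# Age may only increase: lowering it from 41 to 38 is penalized.
violating = unary_constraint_penalty(41, 38, direction=1, weight=87)
# Raising age from 41 to 45 respects the constraint, so the penalty is zero.
respecting = unary_constraint_penalty(41, 45, direction=1, weight=87)
print(violating, respecting)
```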

Initialize the Model and Explainer for FeasibleModelApprox

[10]:
backend = {'model': 'pytorch_model.PyTorchModel',
           'explainer': 'feasible_model_approx.FeasibleModelApprox'}
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
ML Model Sequential(
  (0): Linear(in_features=29, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=2, bias=True)
  (3): Softmax(dim=None)
)
[11]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2, batch_size=2048,
                   validity_reg=76.0, margin=0.344, epochs=25,
                   wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(1, [[0]], 1, 87, pre_trained=1)
[12]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)

The results for ModelApprox show that, unlike with the BaseVAE method, Age also increases when Education increases in the counterfactual explanations. You can experiment with ModelApprox to preserve unary and monotonic constraints on other datasets too. Examples for more advanced approaches such as SCMGenCF and OracleGenCF, which learn to generate feasible counterfactuals under complex feasibility constraints, will be added to this repository soon. More details can be found in our paper.