# Source code for rsmtool.rsmexplain

#!/usr/bin/env python
"""
Explain a SKLL model using SHAP explainers.

:author: Remo Nitschke (rnitschke@ets.org)
:author: Zhaoyang Xie (zxie@etscanada.ca)
:author: Nitin Madnani (nmadnani@ets.org)

:organization: ETS
"""

import glob
import json
import logging
import os
import pickle
import sys
from os import listdir
from os.path import abspath, basename, exists, join, normpath, splitext
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import shap
from skll.data import FeatureSet
from skll.learner import Learner
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run

from .configuration_parser import Configuration, configure
from .modeler import Modeler
from .preprocessor import FeaturePreprocessor
from .reader import DataReader
from .reporter import Reporter
from .utils.commandline import ConfigurationGenerator, setup_rsmcmd_parser
from .utils.constants import VALID_PARSER_SUBCOMMANDS
from .utils.conversion import parse_range
from .utils.logging import LogFormatter
from .utils.wandb import init_wandb_run, log_configuration_to_wandb


def select_examples(
    featureset: "FeatureSet",
    range_size: Optional[Union[int, Tuple[int, int], Tuple[str, ...]]] = None,
) -> Dict[int, str]:
    """
    Sample examples from the given featureset and return indices.

    Parameters
    ----------
    featureset: FeatureSet
        The SKLL FeatureSet object from which we are sampling.
    range_size: Optional[Union[int, Tuple[int, int], Tuple[str, ...]]]
        A user defined sample size, range, or set of IDs. If ``None``, all
        examples in the featureset are selected. If it's a size (int), that
        many examples are randomly selected. If it's a tuple of strings, the
        strings are treated as the explicit IDs of the examples to select.
        Otherwise it must be a tuple of two integers defining the (inclusive)
        range of example positions to select.

    Returns
    -------
    Dict[int, str]
        Dictionary mapping the position of the selected examples to their IDs.

    Raises
    ------
    ValueError
        If any of the selected IDs cannot be located in the featureset.
    """
    fs_ids = featureset.ids
    if range_size is None:
        selected_ids = fs_ids
    elif isinstance(range_size, int):
        selected_ids = shap.sample(fs_ids, range_size)
    # a tuple of strings holds explicitly specified example IDs; checking the
    # element type matters because a bare `isinstance(range_size, tuple)` test
    # would also swallow integer (start, end) ranges and make the range
    # branch below unreachable
    elif isinstance(range_size, tuple) and all(
        isinstance(member, str) for member in range_size
    ):
        selected_ids = np.array(range_size)
    else:
        # a (start, end) tuple of integer positions
        # NOTE: include the end index in the selected examples since it's more intuitive
        start, end = range_size
        selected_ids = fs_ids[start : end + 1]  # noqa: E203

    # make sure that ``selected_ids`` is the same data type as ``fs_ids``
    selected_ids = selected_ids.astype(fs_ids.dtype)

    # find the positions of the selected ids in the original featureset
    try:
        selected_positions = [np.where(fs_ids == id_)[0][0] for id_ in selected_ids]
    except IndexError:
        raise ValueError(
            "Samples could not be selected; please check your configuration file."
        ) from None

    # create and return a dictionary mapping the position to the IDs
    return dict(zip(selected_positions, selected_ids))


def mask(
    learner: Learner,
    featureset: FeatureSet,
    feature_range: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tuple[Dict[int, str], np.ndarray]:
    """
    Sample examples from featureset used by learner.

    An example refers to a specific data instance in the data set. Examples
    are chosen either by sub-sampling specific indices or randomly with a
    fixed size, and their feature values are returned as a dense numpy array.

    Parameters
    ----------
    learner : Learner
        SKLL Learner object that we wish to explain the predictions of.
    featureset : FeatureSet
        SKLL FeatureSet object from which to sample examples.
    feature_range : Optional[Union[int, Tuple[int, int]]]
        An integer requests a random sub-sample of that size; a tuple
        requests the range of examples between its two values; ``None``
        keeps every example with no sub-sampling.

    Returns
    -------
    Dict[int, str]
        Dictionary mapping the position of the selected examples to their IDs.
        This is useful for figuring out which specific examples were selected.
    numpy.ndarray
        A 2D numpy array containing sampled feature rows.
    """
    # recover the raw feature dicts and re-vectorize/select them so that the
    # resulting sparse matrix contains exactly the features the model uses
    raw_feature_dicts = featureset.vectorizer.inverse_transform(featureset.features)
    used_features = learner.feat_selector.transform(
        learner.feat_vectorizer.transform(raw_feature_dicts)
    )

    # sample examples from the featureset
    position_to_id = select_examples(featureset, range_size=feature_range)

    # restrict the matrix rows to the sampled positions only when the user
    # actually asked for a sub-sample; otherwise keep all rows
    if feature_range:
        row_positions = list(position_to_id)
        used_features = used_features[row_positions, :]

    # densify the matrix unless it already is a plain numpy array
    if isinstance(used_features, np.ndarray):
        dense_features = used_features
    else:
        dense_features = used_features.toarray()
    return position_to_id, dense_features


def generate_explanation(
    config_file_or_obj_or_dict: Union[str, Configuration, Dict[str, Any], Path],
    output_dir: str,
    overwrite_output: bool = False,
    logger: Optional[logging.Logger] = None,
    wandb_run: Union[Run, RunDisabled, None] = None,
):
    """
    Generate a shap.Explanation object.

    This function does all the heavy lifting. It loads the model, creates an
    explainer, and generates an explanation object. It then calls
    ``generate_report()`` in order to generate a SHAP report.

    Parameters
    ----------
    config_file_or_obj_or_dict : Union[str, Configuration, Dict[str, Any], Path]
        Path to the experiment configuration file either as a string
        or as a ``pathlib.Path`` object. Users can also pass a
        ``Configuration`` object that is in memory or a Python dictionary
        with keys corresponding to fields in the configuration file. Given a
        configuration file, any relative paths in the configuration file will
        be interpreted relative to the location of the file. Given a
        ``Configuration`` object, relative paths will be interpreted relative
        to the ``configdir`` attribute, that _must_ be set. Given a
        dictionary, the reference path is set to the current directory.
    output_dir : str
        Path to the experiment output directory.
    overwrite_output : bool
        If ``True``, overwrite any existing output under ``output_dir``.
        Defaults to ``False``.
    logger : Optional[logging.Logger]
        A Logger object. If ``None`` is passed, get logger from ``__name__``.
        Defaults to ``None``.
    wandb_run : Union[wandb.wandb_run.Run, wandb.sdk.lib.RunDisabled, None]
        A wandb run object that will be used to log artifacts and tables.
        If ``None`` is passed, a new wandb run will be initialized if
        wandb is enabled in the configuration.
        Defaults to ``None``.

    Raises
    ------
    FileNotFoundError
        If any file contained in ``config_file_or_obj_or_dict`` cannot be located.
    ValueError
        If more than one of ``sample_range``, ``sample_size``, and
        ``sample_ids`` are defined in the configuration file.
    IOError
        If ``output_dir`` already contains a non-empty "output" directory and
        ``overwrite_output`` is ``False``.
    """
    logger = logger if logger else logging.getLogger(__name__)

    # make sure all necessary directories exist
    os.makedirs(output_dir, exist_ok=True)
    csvdir = abspath(join(output_dir, "output"))
    figdir = abspath(join(output_dir, "figure"))
    reportdir = abspath(join(output_dir, "report"))
    os.makedirs(csvdir, exist_ok=True)
    os.makedirs(figdir, exist_ok=True)
    os.makedirs(reportdir, exist_ok=True)

    # Raise an error if the specified output directory
    # already contains a non-empty `output` directory, unless
    # `overwrite_output` was specified, in which case we assume
    # that the user knows what she is doing and simply
    # output a warning saying that the report might
    # not be correct.
    non_empty_csvdir = exists(csvdir) and listdir(csvdir)
    if non_empty_csvdir:
        if not overwrite_output:
            raise IOError(f"'{output_dir}' already contains a non-empty 'output' directory.")
        else:
            logger.warning(
                f"{output_dir} already contains a non-empty 'output' directory. "
                f"The generated report might contain unexpected information from "
                f"a previous experiment."
            )

    configuration = configure("rsmexplain", config_file_or_obj_or_dict)

    logger.info("Saving configuration file.")
    configuration.save(output_dir)

    # If wandb logging is enabled, and wandb_run is not provided,
    # start a wandb run and log configuration
    if wandb_run is None:
        wandb_run = init_wandb_run(configuration)
    log_configuration_to_wandb(wandb_run, configuration)

    # get the experiment ID
    experiment_id = configuration["experiment_id"]

    # check that at most one of `sample_range`, `sample_size`, or
    # `sample_ids` is specified
    has_sample_range = configuration.get("sample_range") is not None
    has_sample_size = configuration.get("sample_size") is not None
    has_sample_ids = configuration.get("sample_ids") is not None
    if sum([has_sample_range, has_sample_size, has_sample_ids]) > 1:
        raise ValueError(
            "You must specify one of 'sample_range', 'sample_size' or 'sample_ids'. "
            "Please refer to the `rsmexplain` documentation for more details."
        )

    # find the rsmtool experiment directory
    experiment_dir = DataReader.locate_files(
        configuration["experiment_dir"], configuration.configdir
    )[0]
    if not experiment_dir:
        raise FileNotFoundError(f"The directory {configuration['experiment_dir']} does not exist.")
    else:
        experiment_output_dir = normpath(join(experiment_dir, "output"))
        if not exists(experiment_output_dir):
            raise FileNotFoundError(
                f"The directory {experiment_dir} does not contain "
                f"the output of an rsmtool experiment."
            )

    # find all the .model files in the experiment output directory
    model_files = glob.glob(join(experiment_output_dir, "*.model"))
    if not model_files:
        raise FileNotFoundError(
            f"The directory {experiment_output_dir} does not contain any rsmtool models."
        )

    experiment_ids = [splitext(basename(mf))[0] for mf in model_files]
    if experiment_id not in experiment_ids:
        raise FileNotFoundError(
            f"{experiment_output_dir} does not contain a model "
            f'for the experiment "{experiment_id}". The following '
            f"experiments are contained in this directory: {experiment_ids}"
        )

    # check that the directory contains the file with feature names and info
    expected_feature_file_name = f"{experiment_id}_feature.csv"
    if not exists(join(experiment_output_dir, expected_feature_file_name)):
        raise FileNotFoundError(
            f"{experiment_output_dir} does not contain the "
            f"required file {expected_feature_file_name} that was "
            f"generated during model training."
        )

    # read the original rsmtool configuration file, if it exists, and ensure
    # that we use its value of `standardize_features` and `truncate_outliers`
    # even if that means we have to override the values specified in the
    # rsmexplain configuration file
    expected_config_file_path = join(experiment_output_dir, f"{experiment_id}_rsmtool.json")
    if exists(expected_config_file_path):
        with open(expected_config_file_path, "r") as rsmtool_configfh:
            rsmtool_configuration = json.load(rsmtool_configfh)
        for option in ["standardize_features", "truncate_outliers"]:
            rsmtool_value = rsmtool_configuration[option]
            rsmexplain_value = configuration[option]
            if rsmexplain_value != rsmtool_value:
                logger.warning(
                    f"overwriting current `{option}` value "
                    f"({rsmexplain_value}) to match "
                    f"value specified in original rsmtool experiment "
                    f"({rsmtool_value})."
                )
                configuration[option] = rsmtool_value
    # if the original experiment rsmtool does not exist, let the user know
    else:
        logger.warning(
            "cannot locate original rsmtool configuration; "
            "ensure that the values of `standardize_features` "
            "and `truncate_outliers` were the same as when running rsmtool."
        )

    # load the background and explain data sets
    (background_data_path, explain_data_path) = DataReader.locate_files(
        [configuration["background_data"], configuration["explain_data"]],
        configuration.configdir,
    )
    if not background_data_path:
        raise FileNotFoundError(f"Input file {configuration['background_data']} does not exist")
    if not explain_data_path:
        raise FileNotFoundError(f"Input file {configuration['explain_data']} does not exist")

    # read the background data, explain data, and feature info files
    feature_info_path = join(experiment_output_dir, f"{experiment_id}_feature.csv")
    file_paths = [background_data_path, explain_data_path, feature_info_path]
    file_names = [
        "background_features",
        "explain_features",
        "feature_info",
    ]
    reader = DataReader(file_paths, file_names)
    container = reader.read(kwargs_dict={"feature_info": {"index_col": 0}})

    # ensure that the background data is large enough for meaningful explanations
    background_data_size = len(container["background_features"])
    if background_data_size < 300:
        logger.error(
            f"The background data {background_data_path} contains only "
            f"{background_data_size} examples. It must contain at least 300 examples "
            "to ensure meaningful explanations."
        )
        sys.exit(1)

    # now pre-process the background and explain data features to match
    # what the model expects
    processor = FeaturePreprocessor(logger=logger)
    (_, processed_container) = processor.process_data(
        configuration, container, context="rsmexplain"
    )

    # create featuresets from pre-processed background and explain features
    background_fs = FeatureSet.from_data_frame(
        processed_container["background_features_preprocessed"], "background"
    )
    explain_fs = FeatureSet.from_data_frame(
        processed_container["explain_features_preprocessed"], "explain"
    )

    # get the SKLL learner object for the rsmtool experiment and its feature names
    modeler = Modeler.load_from_file(join(experiment_output_dir, f"{experiment_id}.model"))
    learner = modeler.learner

    # at this point learner should be a valid SKLL learner but let's confirm
    # to satisfy mypy
    assert learner is not None

    feature_names = list(learner.get_feature_names_out())

    # compute the background kmeans distribution
    _, all_background_features = mask(learner, background_fs)
    background_distribution = shap.kmeans(
        all_background_features, configuration["background_kmeans_size"]
    )

    # get and parse the value of either the sample range or the sample size
    range_size: Optional[Union[int, Tuple[int, int]]]
    if has_sample_size:
        range_size = int(configuration.get("sample_size"))
    elif has_sample_range:
        range_size = parse_range(configuration.get("sample_range"))
    elif has_sample_ids:
        range_size_strs = configuration.get("sample_ids").split(",")
        range_size = tuple([id_.strip() for id_ in range_size_strs])
    else:
        range_size = None
        logger.warning(
            "Since 'sample_range', 'sample_size' and 'sample_ids' are all unspecified, "
            "explanations will be generated for the *entire* data set which "
            "could be very slow, depending on its size."
        )

    # get the features we want to explain
    ids, data_features = mask(learner, explain_fs, feature_range=range_size)

    # define a shap explainer
    explainer = shap.explainers.Sampling(
        learner.model.predict,
        background_distribution,
        feature_names=feature_names,
        seed=np.random.seed(42),
    )

    logger.info(
        f"Generating SHAP explanations for {len(ids)} "
        f"examples from {configuration['explain_data']}"
    )
    explanation = explainer(data_features)

    # add feature names if they aren't already specified
    if explanation.feature_names is None:
        explanation.feature_names = feature_names

    # the explainer does not correctly generate base value arrays sometimes;
    # sometimes it's a single float or sometimes an array with a (1,) shape
    # so let's fix it if that happens
    base_values = explanation.base_values
    if not isinstance(base_values, np.ndarray):
        explanation.base_values = np.repeat(base_values, explanation.values.shape[0])

    # re-generate the explanation here, because manually munging the feature
    # names and base values can break some plots
    # TODO: check if this is still necessary in future versions of shap
    explanation = shap.Explanation(
        explanation.values,
        base_values=explanation.base_values,
        data=explanation.data,
        feature_names=explanation.feature_names,
    )

    # generate the HTML report
    generate_report(explanation, output_dir, ids, configuration, logger, wandb_run=wandb_run)
def generate_report(
    explanation: shap.Explanation,
    output_dir: str,
    ids: Dict[int, str],
    configuration: Configuration,
    logger: Optional[logging.Logger] = None,
    wandb_run: Union[Run, RunDisabled, None] = None,
) -> None:
    """
    Generate an rsmexplain report.

    This function also saves a series of files to disk, including pickled
    versions of the explanation object and the ID dictionary. All SHAP
    values are also saved as CSV files.

    Parameters
    ----------
    explanation: shap.Explanation
        SHAP explanation object containing SHAP values, data points,
        feature names and base values.
    output_dir : str
        Path to the experiment output directory.
    ids: Dict[int, str]
        Dictionary mapping new row indices to original FeatureSet ids.
    configuration: rsmtool.configuration_parser.Configuration
        The Configuration object for rsmexplain.
    logger : Optional[logging.Logger]
        A Logger object. If ``None`` is passed, get logger from ``__name__``.
        Defaults to ``None``.
    wandb_run : Union[wandb.wandb_run.Run, wandb.sdk.lib.RunDisabled, None]
        A wandb run object that will be used to log artifacts and tables.
        If ``None`` is passed, a new wandb run will be initialized if
        wandb is enabled in the configuration.
        Defaults to ``None``.
    """
    logger = logger if logger else logging.getLogger(__name__)

    # get the various output sub-directories which should already exist
    csvdir = abspath(join(output_dir, "output"))
    reportdir = abspath(join(output_dir, "report"))

    # get the experiment ID
    experiment_id = configuration["experiment_id"]

    # first write the explanation object to disk, in case we need it later
    explanation_path = join(csvdir, f"{experiment_id}_explanation.pkl")
    with open(explanation_path, "wb") as pickle_out:
        pickle.dump(explanation, pickle_out)
    configuration["explanation"] = explanation_path

    id_path = join(csvdir, f"{experiment_id}_ids.pkl")
    with open(id_path, "wb") as pickle_out:
        pickle.dump(ids, pickle_out)
    configuration["ids"] = id_path

    # create various versions of the SHAP values to write to disk
    csv_path = join(csvdir, f"{experiment_id}_shap_values.csv")
    shap_frame = pd.DataFrame(
        explanation.values, columns=explanation.feature_names, index=ids.values()
    )
    shap_frame.to_csv(csv_path)

    # compute the various absolute value variants of the SHAP values and
    # write out that dataframe to disk.
    csv_path_abs = join(csvdir, f"{experiment_id}_absolute_shap_values.csv")
    df_abs = pd.DataFrame(
        [shap_frame.abs().mean(), shap_frame.abs().max(), shap_frame.abs().min()],
        index=["abs. mean shap", "abs. max shap", "abs. min shap"],
    ).transpose()
    df_abs.to_csv(csv_path_abs, index_label="")

    # Initialize a reporter instance and add the sections:
    reporter = Reporter(logger=logger, wandb_run=wandb_run)
    general_report_sections = configuration["general_sections"]

    # get any custom sections and locate them to make sure
    # that they exist, otherwise raise an exception
    custom_report_section_paths = configuration["custom_sections"]
    if custom_report_section_paths:
        logger.info("Locating custom report sections")
        custom_report_sections = Reporter.locate_custom_sections(
            custom_report_section_paths, configuration.configdir
        )
    else:
        custom_report_sections = []

    # leverage custom sections to allow users to turn `show_auto_cohorts` on and off
    notebooks_path = Path(__file__).parent / "notebooks"
    notebooks_path = notebooks_path.resolve()
    explanation_notebooks_path = notebooks_path / "explanations"

    # check to see whether a single or multiple examples have been chosen
    has_single_example = len(explanation.values) <= 1
    configuration["has_single_example"] = has_single_example

    # auto cohort plots will be displayed with more than one example selected
    if configuration["show_auto_cohorts"] and not has_single_example:
        custom_report_sections.append(f"{explanation_notebooks_path}/auto_cohorts.ipynb")

    # get user defined section order if available
    section_order = configuration["section_order"]

    # define all of the chosen notebook sections
    chosen_notebook_files = reporter.get_ordered_notebook_files(
        general_report_sections,
        custom_report_sections,
        section_order=section_order,
        context="rsmexplain",
    )

    # add chosen notebook files to configuration and generate the report
    configuration["chosen_notebook_files"] = chosen_notebook_files
    reporter.create_explanation_report(configuration, csvdir, reportdir)


def main(argv: Optional[List[str]] = None) -> None:
    """
    Entry point for the ``rsmexplain`` command-line tool.

    Parameters
    ----------
    argv : Optional[List[str]]
        List of arguments to use instead of ``sys.argv``.
        Defaults to ``None``.
    """
    # if no arguments are passed, then use sys.argv
    if argv is None:
        argv = sys.argv[1:]

    # set up the basic logging configuration
    formatter = LogFormatter()

    # need two handlers, one that prints to stdout for the "run" command and one that prints to
    # stderr from the "generate" command; the latter is important because do not want the warning
    # to show up in the generated configuration file
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(formatter)
    logging.root.setLevel(logging.INFO)
    logger = logging.getLogger(__name__)

    # set up an argument parser via our helper function
    parser = setup_rsmcmd_parser("rsmexplain", uses_output_directory=True, allows_overwriting=True)

    # if no arguments provided then just show the help message
    if len(argv) < 1:
        argv.append("-h")

    # if the first argument is not one of the valid sub-commands
    # or one of the valid optional arguments, then assume that they
    # are arguments for the "run" sub-command. This allows the
    # old style command-line invocations to work without modification.
    if argv[0] not in VALID_PARSER_SUBCOMMANDS + [
        "-h",
        "--help",
        "-V",
        "--version",
    ]:
        args_to_pass = ["run"] + argv
    else:
        args_to_pass = argv
    args = parser.parse_args(args=args_to_pass)

    if args.subcommand == "run":
        # when running, log to stdout
        logging.root.addHandler(stdout_handler)

        # run the experiment
        logger.info(f"Output directory: {args.output_dir}")
        generate_explanation(
            abspath(args.config_file),
            abspath(args.output_dir),
            overwrite_output=args.force_write,
        )
    else:
        # when generating, log to stderr
        logging.root.addHandler(stderr_handler)

        # auto-generate an example configuration and print it to STDOUT
        generator = ConfigurationGenerator(
            "rsmexplain",
            as_string=True,
            suppress_warnings=args.quiet,
            use_subgroups=False,
        )
        configuration = (
            generator.interact(output_file_name=args.output_file.name if args.output_file else None)
            if args.interactive
            else generator.generate()
        )
        print(configuration, file=args.output_file)


if __name__ == "__main__":
    main()