Source code for rsmtool.rsmcompare

#!/usr/bin/env python
"""
Compare two rsmtool/rsmeval experiments.

:author: Jeremy Biggs (jbiggs@ets.org)
:author: Anastassia Loukina (aloukina@ets.org)
:author: Nitin Madnani (nmadnani@ets.org)

:organization: ETS
"""

import glob
import logging
import sys
from os.path import abspath, exists, join, normpath
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from .configuration_parser import Configuration, configure
from .reader import DataReader
from .reporter import Reporter
from .utils.commandline import ConfigurationGenerator, setup_rsmcmd_parser
from .utils.constants import VALID_PARSER_SUBCOMMANDS
from .utils.logging import LogFormatter
from .utils.wandb import init_wandb_run, log_configuration_to_wandb


def check_experiment_id(experiment_dir: str, experiment_id: str) -> None:
    """
    Check that ``experiment_dir`` contains the outputs for ``experiment_id``.

    Output files are searched for in the ``output`` sub-directory of
    ``experiment_dir`` and must be named ``<experiment_id>_<name>.<ext>``.
    Both the directory path and the experiment ID are matched literally:
    any glob metacharacters they contain (``*``, ``?``, ``[``) are escaped
    before searching, so an ID such as ``exp[1]`` only matches files whose
    names really start with ``exp[1]_``.

    Parameters
    ----------
    experiment_dir : str
        Path to the directory with the experiment output.
    experiment_id : str
        The ID of the original experiment used to generate the output.

    Raises
    ------
    FileNotFoundError
        If ``experiment_dir`` does not contain any outputs for ``experiment_id``.
    """
    # escape the literal parts of the pattern so that special characters in
    # the path or in the experiment ID are not interpreted as wildcards;
    # only the trailing "*.*" is meant to be a real glob expression
    escaped_output_dir = glob.escape(join(experiment_dir, "output"))
    pattern = join(escaped_output_dir, f"{glob.escape(experiment_id)}_*.*")

    # raise an error if no output file starts with the experiment ID
    if not glob.glob(pattern):
        raise FileNotFoundError(
            f"The directory {experiment_dir} does not contain "
            f"any outputs of an rsmtool experiment {experiment_id}"
        )


def _locate_experiment_outputs(configuration: Configuration, suffix: str) -> Tuple[str, str]:
    """
    Locate and validate the outputs of one of the two experiments being compared.

    Parameters
    ----------
    configuration : Configuration
        The rsmcompare configuration. The fields ``experiment_id_<suffix>``
        and ``experiment_dir_<suffix>`` are read from it and the directory
        is looked up via ``DataReader.locate_files()`` using its
        ``configdir`` attribute.
    suffix : str
        Which experiment to look up; ``run_comparison()`` passes
        ``"old"`` and ``"new"``.

    Returns
    -------
    csvdir : str
        Normalized path to the ``output`` sub-directory of the experiment.
    figdir : str
        Normalized path to the ``figure`` sub-directory of the experiment.

    Raises
    ------
    FileNotFoundError
        If the experiment directory could not be located.
    FileNotFoundError
        If the experiment directory has no ``output`` or no ``figure``
        sub-directory.
    FileNotFoundError
        If the ``output`` sub-directory has no files for the experiment ID.
    """
    experiment_id = configuration[f"experiment_id_{suffix}"]
    requested_dir = configuration[f"experiment_dir_{suffix}"]

    # ``locate_files()`` returns a list; its first item is falsy
    # if the directory could not be found
    experiment_dir = DataReader.locate_files(requested_dir, configuration.configdir)[0]
    if not experiment_dir:
        raise FileNotFoundError(f"The directory {requested_dir} does not exist.")

    # both sub-directories must be present for the directory
    # to be treated as the output of an rsmtool experiment
    csvdir = normpath(join(experiment_dir, "output"))
    figdir = normpath(join(experiment_dir, "figure"))
    if not exists(csvdir) or not exists(figdir):
        raise FileNotFoundError(
            f"The directory {experiment_dir} does not contain "
            f"the output of an rsmtool experiment."
        )

    # make sure the outputs actually belong to the given experiment ID
    check_experiment_id(experiment_dir, experiment_id)

    return csvdir, figdir


def run_comparison(
    config_file_or_obj_or_dict: Union[str, Configuration, Dict[str, Any], Path], output_dir: str
) -> None:
    """
    Run an rsmcompare experiment using the given configuration.

    Use the given configuration file, object, or dictionary and generate
    the report in the given directory.

    Parameters
    ----------
    config_file_or_obj_or_dict : Union[str, Configuration, Dict[str, Any], Path]
        Path to the experiment configuration file either as a string
        or as a ``pathlib.Path`` object. Users can also pass a
        ``Configuration`` object that is in memory or a Python dictionary
        with keys corresponding to fields in the configuration file. Given a
        configuration file, any relative paths in the configuration file
        will be interpreted relative to the location of the file. Given a
        ``Configuration`` object, relative paths will be interpreted
        relative to the ``configdir`` attribute, that _must_ be set. Given
        a dictionary, the reference path is set to the current directory.
    output_dir : str
        Path to the experiment output directory.

    Raises
    ------
    FileNotFoundError
        If either of the two input directories in ``config_file_or_obj_or_dict``
        do not exist.
    FileNotFoundError
        If the directories do not contain rsmtool outputs at all.
    """
    logger = logging.getLogger(__name__)

    configuration = configure("rsmcompare", config_file_or_obj_or_dict)

    logger.info("Saving configuration file.")
    configuration.save(output_dir)

    # If wandb logging is enabled, and wandb_run is not provided,
    # start a wandb run and log configuration
    wandb_run = init_wandb_run(configuration)
    log_configuration_to_wandb(wandb_run, configuration)

    # locate and validate the "old" experiment first and then the "new" one,
    # so that a problem with the old experiment is always reported first
    csvdir_old, figdir_old = _locate_experiment_outputs(configuration, "old")
    csvdir_new, figdir_new = _locate_experiment_outputs(configuration, "new")

    # are there specific general report sections we want to include?
    general_report_sections = configuration["general_sections"]
    custom_report_section_paths = configuration["custom_sections"]

    # if custom report sections exist, locate sections; otherwise, create empty list
    if custom_report_section_paths:
        logger.info("Locating custom report sections")
        custom_report_sections = Reporter.locate_custom_sections(
            custom_report_section_paths, configuration.configdir
        )
    else:
        custom_report_sections = []

    # get the section order
    section_order = configuration["section_order"]

    # get the subgroups if any
    subgroups = configuration.get("subgroups")

    # Initialize reporter
    reporter = Reporter(logger=logger, wandb_run=wandb_run)

    chosen_notebook_files = reporter.get_ordered_notebook_files(
        general_report_sections,
        custom_report_sections,
        section_order,
        subgroups,
        model_type=None,
        context="rsmcompare",
    )

    # add chosen notebook files to configuration
    configuration["chosen_notebook_files"] = chosen_notebook_files

    # now generate the comparison report
    logger.info("Starting report generation.")
    reporter.create_comparison_report(
        configuration, csvdir_old, figdir_old, csvdir_new, figdir_new, output_dir
    )
def main(argv: Optional[List[str]] = None) -> None:
    """
    Entry point for the ``rsmcompare`` command-line tool.

    The given arguments are dispatched to one of the parser's sub-commands.
    For the ``run`` sub-command, the comparison is run via
    ``run_comparison()`` and log messages are printed to STDOUT. For any
    other sub-command, an example configuration is generated and printed
    to the chosen output file (or to STDOUT if no file was chosen) and
    log messages are printed to STDERR.

    Parameters
    ----------
    argv : Optional[List[str]]
        List of arguments to use instead of ``sys.argv``. The list passed
        in is never modified.
        Defaults to ``None``.
    """
    # if no arguments are passed, then use sys.argv; otherwise work
    # on a private copy so that the caller's list is left untouched
    argv = sys.argv[1:] if argv is None else list(argv)

    # set up the basic logging configuration
    formatter = LogFormatter()

    # we need two handlers, one that prints to stdout
    # for the "run" command and one that prints to stderr
    # from the "generate" command; the latter is necessary
    # because we do not want the warning to show up in the
    # generated configuration file
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(formatter)

    logging.root.setLevel(logging.INFO)
    logger = logging.getLogger(__name__)

    # set up an argument parser via our helper function
    parser = setup_rsmcmd_parser("rsmcompare", uses_output_directory=True, uses_subgroups=True)

    # if we have no arguments at all then just show the help message
    if not argv:
        argv.append("-h")

    # if the first argument is not one of the valid sub-commands
    # or one of the valid optional arguments, then assume that they
    # are arguments for the "run" sub-command. This allows the
    # old style command-line invocations to work without modification.
    if argv[0] not in VALID_PARSER_SUBCOMMANDS + [
        "-h",
        "--help",
        "-V",
        "--version",
    ]:
        args_to_pass = ["run"] + argv
    else:
        args_to_pass = argv
    args = parser.parse_args(args=args_to_pass)

    # call the appropriate function based on which sub-command was run
    if args.subcommand == "run":
        # when running, log to stdout
        logging.root.addHandler(stdout_handler)

        # run the experiment
        logger.info(f"Output directory: {args.output_dir}")
        run_comparison(abspath(args.config_file), abspath(args.output_dir))

    else:
        # when generating, log to stderr
        logging.root.addHandler(stderr_handler)

        # auto-generate an example configuration and print it to STDOUT
        generator = ConfigurationGenerator(
            "rsmcompare",
            as_string=True,
            suppress_warnings=args.quiet,
            use_subgroups=args.subgroups,
        )
        configuration = (
            generator.interact(output_file_name=args.output_file.name if args.output_file else None)
            if args.interactive
            else generator.generate()
        )
        print(configuration, file=args.output_file)


if __name__ == "__main__":
    main()