# Source code for sklearnex.dispatcher

# ==============================================================================
# Copyright 2021 Intel Corporation
# Copyright 2024 Fujitsu Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import sys
from functools import lru_cache
from typing import Optional, Union

from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
from daal4py.sklearn.monkeypatch.dispatcher import PatchMap


def _is_new_patching_available():
    """Tell whether the sklearnex patching route can be used.

    Returns ``False`` right away when the environment variable
    ``OFF_ONEDAL_IFACE`` is set to anything other than ``"0"`` (an unset
    variable counts as ``"0"``). Otherwise the answer is whatever
    ``daal_check_version((2021, "P", 300))`` reports.
    """
    # The environment switch is checked first so that the version check is
    # skipped entirely when the interface is turned off.
    if os.environ.get("OFF_ONEDAL_IFACE", "0") != "0":
        return False
    return daal_check_version((2021, "P", 300))


def _is_preview_enabled() -> bool:
    """Report whether ``SKLEARNEX_PREVIEW`` is present in the environment.

    Only the presence of the variable matters; its value is never inspected.
    """
    return os.environ.get("SKLEARNEX_PREVIEW") is not None


# Comment 2026-01-20: This file has been refactored from the original
# implementation. Initially, the patching map dicts from daal4py and
# sklearnex were somehow meant to share some keys and to be used in
# place of one another under some circumstances, and it appears at
# some point the sklearnex one was meant to inherit things from the
# daal4py one, but that is not the expected behavior anymore. Initially,
# the code tried to accomplish object sharing by merging LRU caches
# of the functions that produce the maps, and there might still be
# traces of this behavior in the refactored code, but by now it is
# expected that the patching map dict objects from both should be
# independent and the functions should be passed the right object
# when needed.


# Note: the keys of this dict are only used as internal IDs to keep
# track of things, and in the functions to check if a given function
# or class is patched. The keys can be arbitrary strings that do not
# necessarily correspond to module paths, but having the full paths
# and names of what they patch makes them easier to identify and debug.
@lru_cache(maxsize=None)
def get_patch_map_core(preview: bool = False) -> PatchMap:
    """Build the sklearnex patching map, cached per value of ``preview``.

    Each entry maps an identifier string to a 4-tuple of
    ``(module, attribute_name, sklearnex_object, sklearn_object)``, where
    ``sklearn_object`` is ``None`` for entries that have no scikit-learn
    counterpart imported here (the ``Incremental*`` estimators of the
    non-preview map).

    Parameters
    ----------
    preview : bool
        If ``True``, start from the non-preview map and, when
        ``_is_new_patching_available()`` holds, merge the preview
        estimators on top of it. If ``False``, return only the
        non-preview map.

    Returns
    -------
    PatchMap
        The patching map. Because of ``lru_cache``, repeated calls with
        the same argument return the very same object, so callers should
        treat it as read-only.

    Notes
    -----
    When ``_is_new_patching_available()`` is false, the map produced by
    daal4py's ``_get_map_of_algorithms`` is returned instead, for both
    values of ``preview``.
    """
    if preview:
        # Reuse the (cached) non-preview map as the base; it is not modified
        # below, the merge creates a new dict.
        mapping = get_patch_map_core(preview=False)

        if _is_new_patching_available():
            import sklearn.covariance as covariance_module
            import sklearn.decomposition as decomposition_module
            from sklearn.covariance import (
                EmpiricalCovariance as EmpiricalCovariance_sklearn,
            )
            from sklearn.decomposition import IncrementalPCA as IncrementalPCA_sklearn

            # Preview classes for patching
            from .preview.covariance import (
                EmpiricalCovariance as EmpiricalCovariance_sklearnex,
            )
            from .preview.decomposition import IncrementalPCA as IncrementalPCA_sklearnex

            # Since the state of the lru_cache without preview cannot be
            # guaranteed to not have already enabled sklearnex algorithms
            # when preview is used, setting the mapping element[1] to None
            # should NOT be done. This may lose track of the unpatched
            # sklearn estimator or function.
            # Preview estimators (covariance and decomposition)
            preview_mapping = {
                "sklearn.covariance.EmpiricalCovariance": (
                    covariance_module,
                    "EmpiricalCovariance",
                    EmpiricalCovariance_sklearnex,
                    EmpiricalCovariance_sklearn,
                ),
                "sklearn.decomposition.IncrementalPCA": (
                    decomposition_module,
                    "IncrementalPCA",
                    IncrementalPCA_sklearnex,
                    IncrementalPCA_sklearn,
                ),
            }
            # LogisticRegressionCV is only added when this version check passes.
            if daal_check_version((2024, "P", 1)):
                import sklearn.linear_model as linear_model_module
                from sklearn.linear_model import (
                    LogisticRegressionCV as LogisticRegressionCV_sklearn,
                )

                from .preview.linear_model import (
                    LogisticRegressionCV as LogisticRegressionCV_sklearnex,
                )

                preview_mapping["sklearn.linear_model.LogisticRegressionCV"] = (
                    linear_model_module,
                    "LogisticRegressionCV",
                    LogisticRegressionCV_sklearnex,
                    LogisticRegressionCV_sklearn,
                )
            # Dict union builds a new dict; on a key collision the preview
            # entry would take precedence over the base entry.
            return mapping | preview_mapping

        return mapping

    # Comment 2026-01-20: This route is untested. It was meant to support
    # a situation in which the 'onedal' module is not compiled, and instead
    # the patching takes classes from daal4py, while still importing from
    # the sklearnex module. This is not tested in any kind of configurations.
    if not _is_new_patching_available():
        from daal4py.sklearn.monkeypatch.dispatcher import _get_map_of_algorithms

        return _get_map_of_algorithms()

    # Scikit-learn* modules
    import sklearn as base_module
    import sklearn.cluster as cluster_module
    import sklearn.covariance as covariance_module
    import sklearn.decomposition as decomposition_module
    import sklearn.dummy as dummy_module
    import sklearn.ensemble as ensemble_module

    # The gradient-boosting module that holds its own reference to
    # DummyRegressor depends on the scikit-learn version.
    if sklearn_check_version("1.4"):
        import sklearn.ensemble._gb as _gb_module
    else:
        import sklearn.ensemble._gb_losses as _gb_module
    import sklearn.linear_model as linear_model_module
    import sklearn.manifold as manifold_module
    import sklearn.metrics as metrics_module
    import sklearn.model_selection as model_selection_module
    import sklearn.neighbors as neighbors_module
    import sklearn.svm as svm_module

    # The parallel helpers are taken from a different module depending on
    # the scikit-learn version.
    if sklearn_check_version("1.2.1"):
        import sklearn.utils.parallel as parallel_module
    else:
        import sklearn.utils.fixes as parallel_module

    # Original scikit-learn objects, kept as the 4th tuple element of each entry
    from sklearn.cluster import DBSCAN as DBSCAN_sklearn
    from sklearn.cluster import KMeans as KMeans_sklearn
    from sklearn.decomposition import PCA as PCA_sklearn
    from sklearn.dummy import DummyRegressor as DummyRegressor_sklearn
    from sklearn.ensemble import ExtraTreesClassifier as ExtraTreesClassifier_sklearn
    from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesRegressor_sklearn
    from sklearn.ensemble import RandomForestClassifier as RandomForestClassifier_sklearn
    from sklearn.ensemble import RandomForestRegressor as RandomForestRegressor_sklearn
    from sklearn.linear_model import ElasticNet as ElasticNet_sklearn
    from sklearn.linear_model import Lasso as Lasso_sklearn
    from sklearn.linear_model import LinearRegression as LinearRegression_sklearn
    from sklearn.linear_model import LogisticRegression as LogisticRegression_sklearn
    from sklearn.linear_model import Ridge as Ridge_sklearn
    from sklearn.manifold import TSNE as TSNE_sklearn
    from sklearn.neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearn
    from sklearn.neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearn
    from sklearn.neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearn
    from sklearn.neighbors import NearestNeighbors as NearestNeighbors_sklearn
    from sklearn.svm import SVC as SVC_sklearn
    from sklearn.svm import SVR as SVR_sklearn
    from sklearn.svm import NuSVC as NuSVC_sklearn
    from sklearn.svm import NuSVR as NuSVR_sklearn

    if sklearn_check_version("1.4"):
        from sklearn.ensemble._gb import DummyRegressor as DummyRegressor_sklearn_gb
    else:
        from sklearn.ensemble._gb_losses import (
            DummyRegressor as DummyRegressor_sklearn_gb,
        )
    from sklearn import config_context as config_context_sklearn
    from sklearn import get_config as get_config_sklearn
    from sklearn import set_config as set_config_sklearn
    from sklearn.metrics import pairwise_distances as pairwise_distances_sklearn
    from sklearn.metrics import roc_auc_score as roc_auc_score_sklearn
    from sklearn.model_selection import train_test_split as train_test_split_sklearn

    if sklearn_check_version("1.2.1"):
        from sklearn.utils.parallel import _FuncWrapper as _FuncWrapper_sklearn
        from sklearn.utils.parallel import get_config as parallel_get_config_sklearn
    else:
        from sklearn.utils.fixes import _FuncWrapper as _FuncWrapper_sklearn
        from sklearn.utils.fixes import get_config as parallel_get_config_sklearn

    # Classes and functions for patching
    from ._config import config_context as config_context_sklearnex
    from ._config import get_config as get_config_sklearnex
    from ._config import set_config as set_config_sklearnex
    from .cluster import DBSCAN as DBSCAN_sklearnex
    from .cluster import KMeans as KMeans_sklearnex
    from .covariance import (
        IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_sklearnex,
    )
    from .decomposition import PCA as PCA_sklearnex
    from .dummy import DummyRegressor as DummyRegressor_sklearnex
    from .ensemble import ExtraTreesClassifier as ExtraTreesClassifier_sklearnex
    from .ensemble import ExtraTreesRegressor as ExtraTreesRegressor_sklearnex
    from .ensemble import RandomForestClassifier as RandomForestClassifier_sklearnex
    from .ensemble import RandomForestRegressor as RandomForestRegressor_sklearnex
    from .linear_model import ElasticNet as ElasticNet_sklearnex
    from .linear_model import (
        IncrementalLinearRegression as IncrementalLinearRegression_sklearnex,
    )
    from .linear_model import IncrementalRidge as IncrementalRidge_sklearnex
    from .linear_model import Lasso as Lasso_sklearnex
    from .linear_model import LinearRegression as LinearRegression_sklearnex
    from .linear_model import LogisticRegression as LogisticRegression_sklearnex
    from .linear_model import Ridge as Ridge_sklearnex
    from .manifold import TSNE as TSNE_sklearnex
    from .metrics import pairwise_distances as pairwise_distances_sklearnex
    from .metrics import roc_auc_score as roc_auc_score_sklearnex
    from .model_selection import train_test_split as train_test_split_sklearnex
    from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
    from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex
    from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex
    from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex
    from .svm import SVC as SVC_sklearnex
    from .svm import SVR as SVR_sklearnex
    from .svm import NuSVC as NuSVC_sklearnex
    from .svm import NuSVR as NuSVR_sklearnex
    from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex

    # Entry layout: key -> (module, attribute name, sklearnex object,
    # original sklearn object or None).
    mapping = {
        "sklearn.cluster.DBSCAN": (
            cluster_module,
            "DBSCAN",
            DBSCAN_sklearnex,
            DBSCAN_sklearn,
        ),
        "sklearn.cluster.KMeans": (
            cluster_module,
            "KMeans",
            KMeans_sklearnex,
            KMeans_sklearn,
        ),
        "sklearn.decomposition.PCA": (
            decomposition_module,
            "PCA",
            PCA_sklearnex,
            PCA_sklearn,
        ),
        "sklearn.svm.SVR": (svm_module, "SVR", SVR_sklearnex, SVR_sklearn),
        "sklearn.svm.SVC": (svm_module, "SVC", SVC_sklearnex, SVC_sklearn),
        "sklearn.svm.NuSVR": (svm_module, "NuSVR", NuSVR_sklearnex, NuSVR_sklearn),
        "sklearn.svm.NuSVC": (svm_module, "NuSVC", NuSVC_sklearnex, NuSVC_sklearn),
        "sklearn.linear_model.ElasticNet": (
            linear_model_module,
            "ElasticNet",
            ElasticNet_sklearnex,
            ElasticNet_sklearn,
        ),
        "sklearn.linear_model.Lasso": (
            linear_model_module,
            "Lasso",
            Lasso_sklearnex,
            Lasso_sklearn,
        ),
        "sklearn.linear_model.LinearRegression": (
            linear_model_module,
            "LinearRegression",
            LinearRegression_sklearnex,
            LinearRegression_sklearn,
        ),
        "sklearn.linear_model.LogisticRegression": (
            linear_model_module,
            "LogisticRegression",
            LogisticRegression_sklearnex,
            LogisticRegression_sklearn,
        ),
        "sklearn.linear_model.Ridge": (
            linear_model_module,
            "Ridge",
            Ridge_sklearnex,
            Ridge_sklearn,
        ),
        "sklearn.linear_model.IncrementalLinearRegression": (
            linear_model_module,
            "IncrementalLinearRegression",
            IncrementalLinearRegression_sklearnex,
            None,
        ),
        "sklearn.manifold.TSNE": (manifold_module, "TSNE", TSNE_sklearnex, TSNE_sklearn),
        "sklearn.metrics.pairwise_distances": (
            metrics_module,
            "pairwise_distances",
            pairwise_distances_sklearnex,
            pairwise_distances_sklearn,
        ),
        "sklearn.metrics.roc_auc_score": (
            metrics_module,
            "roc_auc_score",
            roc_auc_score_sklearnex,
            roc_auc_score_sklearn,
        ),
        "sklearn.model_selection.train_test_split": (
            model_selection_module,
            "train_test_split",
            train_test_split_sklearnex,
            train_test_split_sklearn,
        ),
        "sklearn.neighbors.KNeighborsClassifier": (
            neighbors_module,
            "KNeighborsClassifier",
            KNeighborsClassifier_sklearnex,
            KNeighborsClassifier_sklearn,
        ),
        "sklearn.neighbors.KNeighborsRegressor": (
            neighbors_module,
            "KNeighborsRegressor",
            KNeighborsRegressor_sklearnex,
            KNeighborsRegressor_sklearn,
        ),
        "sklearn.neighbors.NearestNeighbors": (
            neighbors_module,
            "NearestNeighbors",
            NearestNeighbors_sklearnex,
            NearestNeighbors_sklearn,
        ),
        "sklearn.neighbors.LocalOutlierFactor": (
            neighbors_module,
            "LocalOutlierFactor",
            LocalOutlierFactor_sklearnex,
            LocalOutlierFactor_sklearn,
        ),
        "sklearn.ensemble.ExtraTreesClassifier": (
            ensemble_module,
            "ExtraTreesClassifier",
            ExtraTreesClassifier_sklearnex,
            ExtraTreesClassifier_sklearn,
        ),
        "sklearn.ensemble.ExtraTreesRegressor": (
            ensemble_module,
            "ExtraTreesRegressor",
            ExtraTreesRegressor_sklearnex,
            ExtraTreesRegressor_sklearn,
        ),
        "sklearn.ensemble.RandomForestClassifier": (
            ensemble_module,
            "RandomForestClassifier",
            RandomForestClassifier_sklearnex,
            RandomForestClassifier_sklearn,
        ),
        "sklearn.ensemble.RandomForestRegressor": (
            ensemble_module,
            "RandomForestRegressor",
            RandomForestRegressor_sklearnex,
            RandomForestRegressor_sklearn,
        ),
        "sklearn.covariance.IncrementalEmpiricalCovariance": (
            covariance_module,
            "IncrementalEmpiricalCovariance",
            IncrementalEmpiricalCovariance_sklearnex,
            None,
        ),
        "sklearn.dummy.DummyRegressor": (
            dummy_module,
            "DummyRegressor",
            DummyRegressor_sklearnex,
            DummyRegressor_sklearn,
        ),
        # The key keeps the `_gb_losses` name for every scikit-learn version,
        # even though `_gb_module` is `sklearn.ensemble._gb` from 1.4 onwards.
        "sklearn.ensemble._gb_losses.DummyRegressor": (
            _gb_module,
            "DummyRegressor",
            DummyRegressor_sklearnex,
            DummyRegressor_sklearn_gb,
        ),
        # These should be patched even when patching is applied to a single
        # algorithm only
        "sklearn.set_config": (
            base_module,
            "set_config",
            set_config_sklearnex,
            set_config_sklearn,
        ),
        "sklearn.get_config": (
            base_module,
            "get_config",
            get_config_sklearnex,
            get_config_sklearn,
        ),
        "sklearn.config_context": (
            base_module,
            "config_context",
            config_context_sklearnex,
            config_context_sklearn,
        ),
        # Comment 2026-01-20: The comment below was present in earlier code.
        # Whether it's true that is needed or not hasn't been verified.
        # --- end of comment 2026-01-20 ----
        # Necessary for proper work with multiple threads
        "sklearn.utils.parallel.get_config": (
            parallel_module,
            "get_config",
            get_config_sklearnex,
            parallel_get_config_sklearn,
        ),
        "sklearn.utils.parallel._funcwrapper": (
            parallel_module,
            "_FuncWrapper",
            _FuncWrapper_sklearnex,
            _FuncWrapper_sklearn,
        ),
    }

    # IncrementalRidge is only registered when this version check passes.
    if daal_check_version((2024, "P", 600)):
        mapping["sklearn.linear_model.IncrementalRidge"] = (
            linear_model_module,
            "IncrementalRidge",
            IncrementalRidge_sklearnex,
            None,
        )

    return mapping


# This is necessary to properly cache the patch_map when
# using preview.
def get_patch_map() -> PatchMap:
    """Return the patching map that matches the current preview setting.

    The preview flag is read from the environment on every call and handed
    to the cached ``get_patch_map_core`` as a keyword argument.
    """
    return get_patch_map_core(preview=_is_preview_enabled())


def get_patch_names() -> list[str]:
    """Return the identifier strings of every entry in the current patching map."""
    # Iterating the map directly yields its keys.
    return list(get_patch_map())


def patch_sklearn(
    name: Optional[Union[str, list[str]]] = None,
    verbose: bool = True,
    global_patch: bool = False,
    preview: bool = False,
) -> None:
    """Apply patching to the ``sklearn`` module.

    Patches the ``sklearn`` module from |sklearn| to make calls to the
    accelerated versions of estimators and functions from the |sklearnex|,
    either as a whole or on a per-estimator basis.

    Parameters
    ----------
    name : str, list of str, or None
        Names of the desired estimators to patch. Can pass a single instance
        name (e.g. ``"sklearn.linear_model.LogisticRegression"``), or a list
        of names (e.g. ``["sklearn.linear_model.LogisticRegression",
        "sklearn.decomposition.PCA"]``).

        If ``None``, will patch all the supported estimators.

        See the :doc:`algorithm support table <algorithms>` for more
        information.

        Note that functions related to :doc:`config contexts <config-contexts>`
        are always patched regardless of what's passed here.

    verbose : bool
        Whether to print information messages about the patching being
        applied or not.

        Note that this refers only to a message about patching applied
        through this function. Passing ``True`` here does **not** enable
        :doc:`verbose mode <verbose>` for further estimator calls.

        When the message is printed, it will use the Python ``stderr`` stream.

    global_patch : bool
        Whether to apply the patching on the installed ``sklearn`` module
        itself, which is a mechanism that persists across sessions and
        processes.

        If ``True``, the ``sklearn`` module files will be modified to apply
        patching immediately upon import of this module, so that next time,
        importing of ``sklearnex`` will not be necessary.

    preview : bool
        Whether to include the :doc:`preview estimators <preview>` in the
        patching.

        Note that this will set the environment variable
        ``SKLEARNEX_PREVIEW`` if it is not set already; a value that is
        already present is left untouched.

        If environment variable ``SKLEARNEX_PREVIEW`` is set at the moment
        this function is called, preview estimators will be patched
        regardless.

    Raises
    ------
    NotImplementedError
        If the installed scikit-learn version is older than 1.0. No global
        state is modified in that case.

    See Also
    --------
    is_patched_instance: To verify that an instance of an estimator is patched.
    unpatch_sklearn: To undo the patching.

    Notes
    -----
    If estimators from ``sklearn`` have already been imported before
    ``patch_sklearn`` is called, they need to be re-imported in order for
    the patching to take effect.

    Examples
    --------
    >>> from sklearnex import is_patched_instance
    >>> from sklearnex import patch_sklearn
    >>> from sklearn.linear_model import LinearRegression
    >>> is_patched_instance(LinearRegression())
    False
    >>> patch_sklearn()
    >>> from sklearn.linear_model import LinearRegression # now calls sklearnex
    >>> is_patched_instance(LinearRegression())
    True"""
    # Validate before touching any process-wide state, so that a failing
    # call does not leave ``SKLEARNEX_PREVIEW`` behind in the environment.
    if not sklearn_check_version("1.0"):
        raise NotImplementedError(
            "Extension for Scikit-learn* patches apply "
            "for scikit-learn >= 1.0 only ..."
        )

    if preview:
        # Only claim the variable when the user has not exported it already.
        # The sentinel value is what ``unpatch_sklearn`` looks for to decide
        # whether it may delete the variable, so a user-provided value must
        # not be replaced by it.
        os.environ.setdefault("SKLEARNEX_PREVIEW", "enabled_via_patch_sklearn")

    if global_patch:
        from sklearnex.glob.dispatcher import patch_sklearn_global

        patch_sklearn_global(name, verbose)

    from daal4py.sklearn import patch_sklearn as patch_sklearn_orig

    # Read after the environment update above so that preview entries are
    # included when requested.
    patch_map: PatchMap = get_patch_map()

    if name is not None and _is_new_patching_available():
        # Configuration and parallel helpers are patched even when only a
        # subset of estimators was requested.
        names_mandatory = [
            "sklearn.set_config",
            "sklearn.get_config",
            "sklearn.config_context",
            "sklearn.utils.parallel.get_config",
            "sklearn.utils.parallel._funcwrapper",
        ]
        for name_mandatory in names_mandatory:
            patch_sklearn_orig(
                name_mandatory, verbose=False, deprecation=False, map=patch_map
            )

    if isinstance(name, list):
        for algorithm in name:
            patch_sklearn_orig(algorithm, verbose=False, deprecation=False, map=patch_map)
    else:
        # Covers both a single name and ``None`` (patch everything).
        patch_sklearn_orig(name, verbose=False, deprecation=False, map=patch_map)

    # ``sys.stderr`` can be ``None`` (e.g. in GUI applications without a console).
    if verbose and sys.stderr is not None:
        sys.stderr.write(
            "Extension for Scikit-learn* enabled "
            "(https://github.com/uxlfoundation/scikit-learn-intelex)\n"
        )


def unpatch_sklearn(
    name: Optional[Union[str, list[str]]] = None, global_unpatch: bool = False
) -> None:
    """Unpatch scikit-learn.

    Unpatches the ``sklearn`` module, either as a whole or for selected
    estimators.

    .. note::
        If preview mode was enabled through ``patch_sklearn(preview=True)``,
        it will modify the environment variable ``SKLEARNEX_PREVIEW``, by
        deleting it.

    Parameters
    ----------
    name : str, list of str, or None
        Names of the desired estimators to unpatch. Can pass a single
        instance name (e.g. ``"sklearn.linear_model.LogisticRegression"``),
        or a list of names (e.g.
        ``["sklearn.linear_model.LogisticRegression",
        "sklearn.decomposition.PCA"]``).

        If ``None``, will unpatch all the estimators that are patched.

    global_unpatch : bool
        Whether to unpatch the installed ``sklearn`` module itself, if
        patching had been applied to it (see :obj:`patch_sklearn`)."""
    if global_unpatch:
        from sklearnex.glob.dispatcher import unpatch_sklearn_global

        unpatch_sklearn_global()

    from daal4py.sklearn import unpatch_sklearn as unpatch_sklearn_orig

    patch_map: PatchMap = get_patch_map()

    # A single name (or ``None``) is handled as a one-element batch so that
    # both input shapes share the same loop.
    targets = name if isinstance(name, list) else [name]
    for target in targets:
        unpatch_sklearn_orig(target, map=patch_map)

    # Remove the variable only when it carries the marker value written by
    # ``patch_sklearn(preview=True)``.
    preview_marker = os.environ.get("SKLEARNEX_PREVIEW")
    if preview_marker == "enabled_via_patch_sklearn":
        del os.environ["SKLEARNEX_PREVIEW"]


def sklearn_is_patched(
    name: Optional[Union[str, list[str]]] = None, return_map: Optional[bool] = False
) -> Union[bool, dict[str, bool]]:
    """Check patching status.

    Checks whether patching of |sklearn| estimators has been applied, either
    as a whole or for a subset of estimators.

    Parameters
    ----------
    name : str, list of str, or None
        Names of the desired estimators to check for patching status. Can
        pass a single instance name (e.g. ``"LogisticRegression"``), or a
        list of names (e.g. ``["LogisticRegression", "PCA"]``).

        If ``None``, will check for patching status of all estimators.

    return_map : bool
        Whether to return per-estimator patching statuses, or just a single
        result, which will be ``True`` if all the estimators from ``name``
        are patched.

    Returns
    -------
    Check : bool or dict[str, bool]
        The patching status of the desired estimators, either as a whole, or
        on a per-estimator basis (output type controlled by ``return_map``)."""
    from daal4py.sklearn import sklearn_is_patched as sklearn_is_patched_orig

    patch_map = get_patch_map()

    # A single name or ``None`` is delegated as-is, ``return_map`` included.
    if not isinstance(name, list):
        return sklearn_is_patched_orig(name, return_map=return_map, map=patch_map)

    if return_map:
        return {
            algorithm: sklearn_is_patched_orig(algorithm, map=patch_map)
            for algorithm in name
        }

    # Stops querying at the first unpatched entry; an empty list gives ``True``.
    return all(sklearn_is_patched_orig(algorithm, map=patch_map) for algorithm in name)


def is_patched_instance(instance: object) -> bool:
    """Check if given estimator instance is patched with scikit-learn-intelex.

    Parameters
    ----------
    instance : object
        Python object, usually a scikit-learn estimator instance.

    Returns
    -------
    Check : bool
        Boolean whether instance is a daal4py or sklearnex estimator.
        Objects without a ``__module__`` attribute, or whose ``__module__``
        is not a string, are reported as not patched.

    Examples
    --------
    >>> from sklearnex import is_patched_instance
    >>> from sklearn.linear_model import LinearRegression
    >>> from sklearnex.linear_model import LinearRegression as patched_LR
    >>> is_patched_instance(LinearRegression())
    False
    >>> is_patched_instance(patched_LR())
    True"""
    module = getattr(instance, "__module__", None)
    # ``__module__`` may be missing or set to ``None``; a substring test on a
    # non-string would raise ``TypeError``.
    if not isinstance(module, str):
        return False
    return "daal4py" in module or "sklearnex" in module