Source code for sklearnex._config

# ==============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import sys
from contextlib import contextmanager

from sklearn import get_config as skl_get_config
from sklearn import set_config as skl_set_config

from daal4py.sklearn._utils import sklearn_check_version
from onedal._config import _get_config as onedal_get_config

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "get_config",
    "set_config",
    "config_context",
]

# Python 3.13+ strips the common leading indentation from docstrings at compile
# time; older versions keep it verbatim. The docstrings of ``set_config`` and
# ``config_context`` embed this text at a placeholder and are expected to be
# indented by four spaces, so on older interpreters every injected line after
# the first must carry that same indent to line up with the rest of the docstring.
tab = "    " if sys.version_info < (3, 13) else ""

# Shared numpydoc "Parameters" section (plus warnings and notes) substituted for
# the ``%_options_docstring%`` placeholder in ``set_config`` and ``config_context``.
_options_docstring = f"""Parameters
{tab}----------
{tab}target_offload : str or dpctl.SyclQueue or None
{tab}    The device used to perform computations, either as a string indicating a name
{tab}    recognized by the SYCL runtime, such as ``"gpu"``, ``"gpu:0"``, or as a
{tab}    :obj:`dpctl.SyclQueue` object indicating where to move the data.
{tab}
{tab}    Assuming SYCL-related dependencies are installed, the list of devices recognized
{tab}    by SYCL can be retrieved through the CLI tool ``sycl-ls`` in a shell, or through
{tab}    :obj:`dpctl.get_devices` in a Python process.
{tab}
{tab}    String ``"auto"`` is also accepted.
{tab}
{tab}    Global default: ``"auto"``.
{tab}
{tab}allow_fallback_to_host : bool or None
{tab}    If ``True``, allows computations to fall back to host device (CPU) when an unsupported
{tab}    operation is attempted on GPU through ``target_offload``.
{tab}
{tab}    Global default: ``False``.
{tab}
{tab}allow_sklearn_after_onedal : bool or None
{tab}    If ``True``, allows computations to fall back to stock scikit-learn when no
{tab}    accelerated version of the operation is available (see :ref:`algorithms`).
{tab}
{tab}    Global default: ``True``.
{tab}
{tab}use_raw_input : bool or None
{tab}    If ``True``, uses the raw input data in some SPMD oneDAL backend computations
{tab}    without any checks on data consistency or validity. Note that this can be
{tab}    better achieved through usage of :ref:`array API classes <array_api>` without
{tab}    ``target_offload``. Not recommended for general use.
{tab}
{tab}    Global default: ``False``.
{tab}
{tab}    .. deprecated:: 2026.0
{tab}
{tab}sklearn_configs : kwargs
{tab}    Other settings accepted by scikit-learn. See :obj:`sklearn.set_config` for
{tab}    details.
{tab}
{tab}Warnings
{tab}--------
{tab}Using ``use_raw_input=True`` is not recommended for general use as it
{tab}bypasses data consistency checks, which may lead to unexpected behavior. It is
{tab}recommended to use the newer :ref:`array API <array_api>` instead.
{tab}
{tab}Note
{tab}----
{tab}Usage of ``target_offload`` requires additional dependencies - see
{tab}:ref:`GPU support <oneapi_gpu>` for more information."""


def get_config():
    """Retrieve current values for configuration set by :func:`set_config`.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global configuration.
    set_config : Set global configuration.
    """
    # Begin with scikit-learn's settings, then overlay the sklearnex/oneDAL
    # settings so that they take precedence on any key present in both.
    merged = dict(skl_get_config())
    merged.update(onedal_get_config())
    return merged
def set_config(
    target_offload=None,
    allow_fallback_to_host=None,
    allow_sklearn_after_onedal=None,
    use_raw_input=None,
    **sklearn_configs,
):  # numpydoc ignore=PR01,PR07
    """Set global configuration.

    %_options_docstring%

    See Also
    --------
    config_context : Context manager for global configuration.
    get_config : Retrieve current values of the global configuration.
    """
    # Every keyword not recognised as a sklearnex option is forwarded to
    # scikit-learn unchanged.
    skl_set_config(**sklearn_configs)

    # sklearnex options in application order; ``None`` means "leave as is".
    requested = (
        ("target_offload", target_offload),
        ("allow_fallback_to_host", allow_fallback_to_host),
        ("allow_sklearn_after_onedal", allow_sklearn_after_onedal),
        ("use_raw_input", use_raw_input),
    )
    # ``copy=False`` presumably returns the live configuration mapping, so the
    # writes below take effect directly (TODO confirm in onedal._config).
    live_config = onedal_get_config(copy=False)
    for option, value in requested:
        if value is not None:
            live_config[option] = value


# Splice the shared parameter documentation into the placeholder above.
set_config.__doc__ = set_config.__doc__.replace(
    "%_options_docstring%", _options_docstring
)
@contextmanager
def config_context(**new_config):  # numpydoc ignore=PR01,PR07
    """Context manager for local scikit-learn-intelex configurations.

    %_options_docstring%

    Note
    ----
    All settings, not just those presently modified, will be returned to their
    previous values when the context manager is exited.

    See Also
    --------
    set_config : Set global configuration.
    get_config : Retrieve current values of the global configuration.
    """
    # Snapshot every setting (scikit-learn and sklearnex) before changing any.
    old_config = get_config()
    try:
        # Applied inside ``try``: ``set_config`` performs several sequential
        # writes, so if it fails partway through, the ``finally`` clause still
        # restores the snapshot instead of leaking the partially applied values.
        set_config(**new_config)
        yield
    finally:
        set_config(**old_config)


# Splice the shared parameter documentation into the placeholder above.
config_context.__doc__ = config_context.__doc__.replace(
    "%_options_docstring%", _options_docstring
)