Array API support
Overview
Many estimators from the Extension for Scikit-learn* accept inputs of classes that conform to the
Array API specification, such as dpnp.ndarray or torch.Tensor, in methods like .fit()
and .predict().
This is particularly useful for GPU computations, as it allows performing operations on inputs that are already
on GPU without moving the data between host and device.
Important
Array API is disabled by default in scikit-learn. In order to get array API support in the Extension for Scikit-learn*, it must
be enabled in scikit-learn, which requires either changing
global settings or using a config_context, plus installing additional dependencies such as array-api-compat.
Array API inputs can hold data on a SYCL-enabled device (e.g. an Intel GPU), as
supported for example by PyTorch
and dpnp. If array API support is enabled and the requested operation (e.g. a call to .fit() / .predict()
on the estimator class being used) is supported on device/GPU, computations
are performed on the device where the data lives, without any data transfers. Note that all of
the inputs (e.g. X and y passed to .fit() methods) must be allocated on the same device for this to
work. If the requested operation is not supported on the device where the data lives, it falls
back either to scikit-learn or, when supported, to an accelerated CPU version from the Extension for Scikit-learn*.
These fallbacks are controlled by the options allow_sklearn_after_onedal (default is True) and
allow_fallback_to_host (default is False), respectively, which are accepted by config_context and set_config after
patching scikit-learn or when importing those directly from sklearnex.
Note
Under default settings for set_config / config_context, operations that are not supported on GPU will
fall back to scikit-learn instead of falling back to CPU versions from the Extension for Scikit-learn*.
If array API is enabled for scikit-learn and the estimator being used has array API support in scikit-learn (which can be
verified through the array_api_support attribute of the tags returned by sklearn.utils.get_tags), then array API inputs whose data
is allocated neither on CPU nor on a SYCL device will be forwarded directly to the unpatched methods from scikit-learn,
without using the accelerated versions from this library, regardless of the allow_sklearn_after_onedal option.
While other array API inputs (e.g. torch arrays with data allocated on a non-SYCL device) might be supported by the Extension for Scikit-learn* in cases where the same class from scikit-learn doesn’t support array API, note that the data will be transferred to host if it isn’t already, and the computations will happen on CPU.
Hint
Enable Verbose Mode to see information about whether data transfers happen during an operation or not, whether an accelerated version from the extension is used, and where (CPU/device) the operation is executed.
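A minimal sketch of turning on verbose output from Python. It assumes, based on the Verbose Mode page rather than on anything stated in this section, that the extension emits its messages through a standard logging logger named sklearnex and also reads the SKLEARNEX_VERBOSE environment variable.

```python
import logging
import os

# Option 1 (assumption): set the variable before sklearnex is imported
os.environ["SKLEARNEX_VERBOSE"] = "INFO"

# Option 2 (assumption): raise the level of the "sklearnex" logger directly
logger = logging.getLogger("sklearnex")
logger.setLevel(logging.INFO)
```

Either option on its own is likely enough; both are shown here only for completeness.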
When passing array API inputs to methods such as .predict() of estimators with array API support, the output
will always be of the same class as the inputs, but be aware that array attributes of fitted models (e.g. coef_
in a linear model) will not necessarily be of the same class as array API inputs passed to .fit(), even though
in many cases they are.
Warning
If array API inputs are passed to an estimator’s .fit(), subsequent data passed to methods such as
.predict() or .score() of the fitted model must reside on the same device - meaning: a model that
was fitted with GPU arrays cannot make predictions on CPU arrays, and a model fitted with CPU array API inputs
cannot make predictions on GPU arrays, even if they are of the same class. Attempting to pass data on the
wrong device might lead to process-wide crashes.
Note
An estimator fitted to array API inputs should only be passed objects of the same class that was passed to
.fit() in subsequent calls to .predict(), .score(), and similar. In some cases, it might be
possible to pass a different class at prediction time without errors (particularly when fitting on CPU only),
but this is generally not supported and users should not rely on these interchanges working reliably.
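The two constraints above (same device and same class as at fit time) can be checked defensively before calling .predict(). The helper below is a hypothetical sketch and not part of sklearnex: it relies only on the .device attribute that the Array API specification requires of array objects, and the names check_matches_fit_input and FakeArray are made up for illustration.

```python
class FakeArray:
    """Stand-in for an array API object; real arrays expose .device too."""

    def __init__(self, device):
        self.device = device


def check_matches_fit_input(fit_input, new_input):
    """Raise if new_input differs in class or device from fit_input."""
    if type(new_input) is not type(fit_input):
        raise TypeError(
            f"expected {type(fit_input).__name__}, got {type(new_input).__name__}"
        )
    if new_input.device != fit_input.device:
        raise ValueError(
            f"data is on {new_input.device!r}, model was fitted on {fit_input.device!r}"
        )


X_train = FakeArray(device="gpu:0")
check_matches_fit_input(X_train, FakeArray(device="gpu:0"))  # passes silently

try:
    check_matches_fit_input(X_train, FakeArray(device="cpu"))
except ValueError as exc:
    print("rejected:", exc)
```

In practice, one would keep a reference to the training array (or just its type and device) and run a check like this before each call to .predict() or .score(), turning a potential process-wide crash into an ordinary Python exception.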
Note
The target_offload option in config contexts and settings is not intended to work with array API
classes that have USM data. In order to ensure that computations
happen on the intended device under array API, make sure that the data is already on the desired device.
Supported classes
The following patched classes have support for array API inputs:
Note
While full array API support is currently not implemented for all classes, dpnp.ndarray inputs are supported by all the classes that have GPU support. Note however that if array API support is not enabled in scikit-learn, when passing these classes as inputs, data will be transferred to host and then back to device instead of being used directly.
Example usage
GPU operations on GPU arrays
# Array API support from sklearn requires enabling it on SciPy too
import os
os.environ["SCIPY_ARRAY_API"] = "1"
import numpy as np
import torch
from sklearnex import config_context
from sklearnex.linear_model import LinearRegression
# Random data for a regression problem
rng = np.random.default_rng(seed=123)
X_np = rng.standard_normal(size=(100, 10), dtype=np.float32)
y_np = rng.standard_normal(size=100, dtype=np.float32)
# Torch offers an array-API-compliant class where data can be on GPU (referred to as 'xpu')
X = torch.tensor(X_np, device="xpu")
y = torch.tensor(y_np, device="xpu")
# Important to note again that array API must be enabled on scikit-learn
model = LinearRegression()
with config_context(array_api_dispatch=True):
    model.fit(X, y)
# Fitted attributes are now of the same class as inputs
assert isinstance(model.coef_, torch.Tensor)
# Predictions are also of the same class
with config_context(array_api_dispatch=True):
    pred = model.predict(X[:5])
assert isinstance(pred, torch.Tensor)
# Array API support from sklearn requires enabling it on SciPy too
import os
os.environ["SCIPY_ARRAY_API"] = "1"
import numpy as np
import dpnp
from sklearnex import config_context
from sklearnex.linear_model import LinearRegression
# Random data for a regression problem
rng = np.random.default_rng(seed=123)
X_np = rng.standard_normal(size=(100, 10), dtype=np.float32)
y_np = rng.standard_normal(size=100, dtype=np.float32)
# DPNP offers an array-API-compliant class where data can be on GPU
X = dpnp.array(X_np, device="gpu")
y = dpnp.array(y_np, device="gpu")
# Important to note again that array API must be enabled on scikit-learn
model = LinearRegression()
with config_context(array_api_dispatch=True):
    model.fit(X, y)
# Fitted attributes are now of the same class as inputs
assert isinstance(model.coef_, X.__class__)
# Predictions are also of the same class
with config_context(array_api_dispatch=True):
    pred = model.predict(X[:5])
assert isinstance(pred, X.__class__)
array-api-strict
Example code showcasing how to use array-api-strict
arrays to run patched sklearn.cluster.DBSCAN.
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import array_api_strict
from sklearnex import config_context, patch_sklearn
patch_sklearn()
from sklearn.cluster import DBSCAN
X = array_api_strict.asarray(
    [[1.0, 2.0], [2.0, 2.0], [2.0, 3.0], [8.0, 7.0], [8.0, 8.0], [25.0, 80.0]],
    dtype=array_api_strict.float32,
)
# Could be launched without `config_context(array_api_dispatch=True)`. For
# sklearnex, that context manager only guarantees that, in case of a fallback
# to stock scikit-learn, fitted attributes come from the same array API
# namespace as the training data.
clustering = DBSCAN(eps=3, min_samples=2).fit(X)
print("Fitted labels:\n", clustering.labels_)