Array API support

Overview

Many estimators from the Extension for Scikit-learn* accept data classes that conform to the Array API specification, such as dpnp.ndarray or torch.Tensor, as inputs to methods like .fit() and .predict(). This is particularly useful for GPU computations, as it allows operating on inputs that are already on the GPU without moving the data between host and device.

Important

Array API is disabled by default in scikit-learn. In order to get array API support in the Extension for Scikit-learn*, it must be enabled in scikit-learn, which requires either changing global settings or using a config_context, plus installing additional dependencies such as array-api-compat.

When array API support is enabled and the inputs are array API classes whose data is on a SYCL-enabled device (e.g. an Intel GPU), as supported for example by PyTorch and dpnp, computations are performed on the device where the data lives, without any data transfers, provided that the requested operation (e.g. a call to .fit() / .predict() on the estimator class being used) is supported on device/GPU. Note that all of the inputs (e.g. X and y passed to .fit() methods) must be allocated on the same device for this to work.

If the requested operation is not supported on the device where the data lives, it will fall back either to scikit-learn or, when supported, to an accelerated CPU version from the Extension for Scikit-learn*. These fallbacks are controlled through the options allow_sklearn_after_onedal (default is True) and allow_fallback_to_host (default is False), respectively. Both options are accepted by config_context and set_config, either after patching scikit-learn or when importing those functions directly from sklearnex.

Note

Under default settings for set_config / config_context, operations that are not supported on GPU will fall back to scikit-learn instead of falling back to CPU versions from the Extension for Scikit-learn*.

If array API is enabled for scikit-learn and the estimator being used has array API support on scikit-learn (which can be verified by attribute array_api_support from sklearn.utils.get_tags), then array API inputs whose data is allocated neither on CPU nor on a SYCL device will be forwarded directly to the unpatched methods from scikit-learn, without using the accelerated versions from this library, regardless of option allow_sklearn_after_onedal.

Other array API inputs (e.g. torch tensors with data allocated on a non-SYCL device) might still be supported by the Extension for Scikit-learn* in cases where the same class from scikit-learn doesn’t support array API. Note, however, that in those cases the data will be transferred to host if it isn’t there already, and the computations will happen on CPU.
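The fallback rules for SYCL-device inputs described above can be summarized as a small decision function. This is only an illustrative sketch, not code from the library: resolve_backend and its return labels are hypothetical names, the option names and defaults are those given in the text, and two points are assumptions rather than documented behavior: that the host fallback is tried before scikit-learn when both are allowed, and that an error is raised when no fallback is permitted.

```python
def resolve_backend(
    supported_on_device: bool,
    supported_on_host: bool,
    allow_sklearn_after_onedal: bool = True,  # default stated in the text
    allow_fallback_to_host: bool = False,     # default stated in the text
) -> str:
    """Hypothetical model of where a call with SYCL-device inputs is executed.

    Assumes array API dispatch is enabled and all inputs live on one device.
    """
    if supported_on_device:
        # No data transfer: computation stays where the data lives.
        return "sklearnex on device"
    if allow_fallback_to_host and supported_on_host:
        # Assumption: the accelerated CPU version is preferred over scikit-learn.
        return "sklearnex on host CPU"
    if allow_sklearn_after_onedal:
        return "stock scikit-learn"
    # Assumption: with every fallback disabled, the call cannot proceed.
    raise RuntimeError("operation not supported on device and no fallback allowed")


# Default options: an operation unsupported on GPU goes to scikit-learn.
print(resolve_backend(supported_on_device=False, supported_on_host=True))

# Opting in to the host fallback selects the accelerated CPU version instead.
print(
    resolve_backend(
        supported_on_device=False,
        supported_on_host=True,
        allow_fallback_to_host=True,
    )
)
```

Enabling verbose mode remains the reliable way to see which path a real call takes; the function above only restates the prose.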

Hint

Enable Verbose Mode to see information about whether data transfers happen during an operation or not, whether an accelerated version from the extension is used, and where (CPU/device) the operation is executed.

When passing array API inputs to methods such as .predict() of estimators with array API support, the output will always be of the same class as the inputs. Be aware, however, that array attributes of fitted models (e.g. coef_ in a linear model) will not necessarily be of the same class as the array API inputs passed to .fit(), even though in many cases they are.

Warning

If array API inputs are passed to an estimator’s .fit(), subsequent data passed to methods such as .predict() or .score() of the fitted model must reside on the same device - meaning: a model that was fitted with GPU arrays cannot make predictions on CPU arrays, and a model fitted with CPU array API inputs cannot make predictions on GPU arrays, even if they are of the same class. Attempting to pass data on the wrong device might lead to process-wide crashes.

Note

An estimator fitted to array API inputs should only be passed objects of the same class that was passed to .fit() in subsequent calls to .predict(), .score(), and similar. In some cases, it might be possible to pass a different class at prediction time without errors (particularly when fitting on CPU only), but this is generally not supported and users should not rely on these interchanges working reliably.
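Because passing data on the wrong device can crash the whole process, it may be worth validating inputs before calling .predict() or .score(). The sketch below is one possible guard, not part of the library: check_matches_fit_inputs is a hypothetical helper, FakeArray is a stand-in used only to keep the example self-contained, and the check assumes that arrays expose a .device attribute (as the Array API standard describes) whose values can be compared with ==.

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Stand-in for an array API object; it only carries a device label."""

    device: str


def check_matches_fit_inputs(fit_input, *arrays):
    """Raise early if later inputs differ in class or device from the fit data.

    Hypothetical helper: keep a reference to the array passed to .fit() and
    call this before .predict(), .score(), and similar methods.
    """
    fit_device = getattr(fit_input, "device", None)
    for position, arr in enumerate(arrays):
        if type(arr) is not type(fit_input):
            raise TypeError(
                f"input {position} is {type(arr).__name__}, "
                f"but the model was fitted on {type(fit_input).__name__}"
            )
        if getattr(arr, "device", None) != fit_device:
            raise ValueError(
                f"input {position} is on device {arr.device!r}, "
                f"but the model was fitted on {fit_device!r}"
            )


X_train = FakeArray(device="gpu:0")

# Same class and same device: the check passes silently.
check_matches_fit_inputs(X_train, FakeArray(device="gpu:0"))

# Same class but different device: rejected before reaching the estimator.
try:
    check_matches_fit_inputs(X_train, FakeArray(device="cpu"))
except ValueError as err:
    print(f"Rejected: {err}")
```

With real inputs, X_train would be the dpnp.ndarray or torch.Tensor passed to .fit(); whether device objects of a given library support == comparison should be confirmed for that library.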

Note

The target_offload option in config contexts and settings is not intended to work with array API classes that have USM data. In order to ensure that computations happen on the intended device under array API, make sure that the data is already on the desired device.

Supported classes

The following patched classes have support for array API inputs:

Note

While full array API support is currently not implemented for all classes, dpnp.ndarray inputs are supported by all the classes that have GPU support. Note however that if array API support is not enabled in scikit-learn, when passing these classes as inputs, data will be transferred to host and then back to device instead of being used directly.

Example usage

GPU operations on GPU arrays

# Array API support from sklearn requires enabling it on SciPy too
import os
os.environ["SCIPY_ARRAY_API"] = "1"

import numpy as np
import torch
from sklearnex import config_context
from sklearnex.linear_model import LinearRegression

# Random data for a regression problem
rng = np.random.default_rng(seed=123)
X_np = rng.standard_normal(size=(100, 10), dtype=np.float32)
y_np = rng.standard_normal(size=100, dtype=np.float32)

# Torch offers an array-API-compliant class where data can be on GPU (referred to as 'xpu')
X = torch.tensor(X_np, device="xpu")
y = torch.tensor(y_np, device="xpu")

# Important to note again that array API must be enabled on scikit-learn
model = LinearRegression()
with config_context(array_api_dispatch=True):
    model.fit(X, y)

# Fitted attributes are now of the same class as inputs
assert isinstance(model.coef_, torch.Tensor)

# Predictions are also of the same class
with config_context(array_api_dispatch=True):
    pred = model.predict(X[:5])
assert isinstance(pred, torch.Tensor)

array-api-strict

Example code showcasing how to use array-api-strict arrays to run patched sklearn.cluster.DBSCAN.

# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import array_api_strict

from sklearnex import config_context, patch_sklearn

patch_sklearn()

from sklearn.cluster import DBSCAN

X = array_api_strict.asarray(
    [[1.0, 2.0], [2.0, 2.0], [2.0, 3.0], [8.0, 7.0], [8.0, 8.0], [25.0, 80.0]],
    dtype=array_api_strict.float32,
)

# This could also be run without `config_context(array_api_dispatch=True)`. For
# sklearnex, that context manager only guarantees that, in case of a fallback to
# stock scikit-learn, fitted attributes come from the same array API namespace as
# the training data.
clustering = DBSCAN(eps=3, min_samples=2).fit(X)

print("Fitted labels:\n", clustering.labels_)