Source code for sklearnex.linear_model.incremental_linear
# ===============================================================================# Copyright 2024 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# ===============================================================================importnumbersimportwarningsimportnumpyasnpfromsklearn.baseimportBaseEstimator,MultiOutputMixin,RegressorMixinfromsklearn.metricsimportr2_scorefromsklearn.utilsimportcheck_array,gen_batchesfromsklearn.utils.validationimportcheck_is_fittedfromdaal4py.sklearn._n_jobs_supportimportcontrol_n_jobsfromdaal4py.sklearn._utilsimportdaal_check_version,sklearn_check_versionfromonedal.linear_modelimport(IncrementalLinearRegressionasonedal_IncrementalLinearRegression,)ifsklearn_check_version("1.2"):fromsklearn.utils._param_validationimportIntervalifsklearn_check_version("1.6"):fromsklearn.utils.validationimportvalidate_dataelse:validate_data=BaseEstimator._validate_datafromonedal.common.hyperparametersimportget_hyperparametersfrom.._device_offloadimportdispatch,wrap_output_datafrom.._utilsimportIntelEstimator,PatchingConditionsChain,register_hyperparameters
@register_hyperparameters(
    {
        "fit": get_hyperparameters("linear_regression", "train"),
        "partial_fit": get_hyperparameters("linear_regression", "train"),
    }
)
@control_n_jobs(
    decorated_methods=["fit", "partial_fit", "predict", "score", "_onedal_finalize_fit"]
)
class IncrementalLinearRegression(
    IntelEstimator, MultiOutputMixin, RegressorMixin, BaseEstimator
):
    """
    Trains a linear regression model, allows for computation if the data are split into
    batches.

    The user can use the ``partial_fit`` method to provide a single batch of
    data or use the ``fit`` method to provide the entire dataset.

    Parameters
    ----------
    fit_intercept : bool, default=True
        Whether to calculate the intercept for this model. If set
        to False, no intercept will be used in calculations
        (i.e. data is expected to be centered).

    copy_X : bool, default=True
        If True, X will be copied; else, it may be overwritten.

    n_jobs : int, default=None
        The number of jobs to use for the computation.

    batch_size : int, default=None
        The number of samples to use for each batch. Only used when calling
        ``fit``. If ``batch_size`` is ``None``, then ``batch_size``
        is inferred from the data and set to ``5 * n_features``.

    Attributes
    ----------
    coef_ : array of shape (n_features, ) or (n_targets, n_features)
        Estimated coefficients for the linear regression problem.
        If multiple targets are passed during the fit (y 2D), this
        is a 2D array of shape (n_targets, n_features), while if only
        one target is passed, this is a 1D array of length n_features.

    intercept_ : float or array of shape (n_targets,)
        Independent term in the linear model. Set to 0.0 if
        `fit_intercept = False`.

    n_samples_seen_ : int
        The number of samples processed by the estimator. Will be reset on
        new calls to ``fit``, but increments across ``partial_fit`` calls.
        It should be not less than `n_features_in_` if `fit_intercept`
        is False and not less than `n_features_in_` + 1 if `fit_intercept`
        is True to obtain regression coefficients.

    batch_size_ : int
        Inferred batch size from ``batch_size``.

    n_features_in_ : int
        Number of features seen during ``fit`` or ``partial_fit``.

    Note
    ----
    Serializing instances of this class will trigger a forced finalization of
    calculations. Since finalize_fit can't be dispatched without directly provided
    queue and the dispatching policy can't be serialized, the computation is finalized
    during serialization call and the policy is not saved in serialized data.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearnex.linear_model import IncrementalLinearRegression
    >>> inclr = IncrementalLinearRegression(batch_size=2)
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 10]])
    >>> y = np.array([1.5, 3.5, 5.5, 8.5])
    >>> inclr.partial_fit(X[:2], y[:2])
    >>> inclr.partial_fit(X[2:], y[2:])
    >>> inclr.coef_
    np.array([0.5, 0.5])
    >>> inclr.intercept_
    np.array(0.)
    >>> inclr.fit(X, y)
    >>> inclr.coef_
    np.array([0.5, 0.5])
    >>> inclr.intercept_
    np.array(0.)
    """

    # Backend estimator class; wrapped in ``staticmethod`` so that accessing it
    # through an instance does not bind ``self`` as its first argument.
    _onedal_incremental_linear = staticmethod(onedal_IncrementalLinearRegression)

    if sklearn_check_version("1.2"):
        # Declarative constraints consumed by ``self._validate_params()``.
        _parameter_constraints: dict = {
            "fit_intercept": ["boolean"],
            "copy_X": ["boolean"],
            "n_jobs": [Interval(numbers.Integral, -1, None, closed="left"), None],
            "batch_size": [Interval(numbers.Integral, 1, None, closed="left"), None],
        }

    def __init__(self, *, fit_intercept=True, copy_X=True, n_jobs=None, batch_size=None):
        self.fit_intercept = fit_intercept
        self.copy_X = copy_X
        self.n_jobs = n_jobs
        self.batch_size = batch_size

    def _onedal_supported(self, method_name, *data):
        """Return the patching-conditions chain for ``method_name``.

        No conditions are registered on the chain here; ``data`` is unused.
        """
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
        )
        return patching_status

    # The same (condition-free) chain is used for both device kinds.
    _onedal_cpu_supported = _onedal_supported
    _onedal_gpu_supported = _onedal_supported

    def _onedal_predict(self, X, queue=None):
        """Validate ``X`` and predict with the backend estimator.

        Pending partial results are finalized first when
        ``_need_to_finalize`` is set.
        """
        if sklearn_check_version("1.2"):
            self._validate_params()

        if sklearn_check_version("1.0"):
            # reset=False: validate against the already recorded input
            # metadata instead of overwriting it.
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                copy=self.copy_X,
                reset=False,
            )
        else:
            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                copy=self.copy_X,
            )

        assert hasattr(self, "_onedal_estimator")
        if self._need_to_finalize:
            self._onedal_finalize_fit()
        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return ``r2_score`` of ``y`` against the predictions for ``X``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    def _onedal_partial_fit(self, X, y, check_input=True, queue=None):
        """Feed one batch to the backend estimator and update sample counters.

        When ``check_input`` is true, ``X`` and ``y`` are validated first, with
        the finiteness check disabled. The backend estimator is created on the
        first call and ``_need_to_finalize`` is set so that results are
        finalized lazily.
        """
        first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0

        if sklearn_check_version("1.2"):
            self._validate_params()

        if check_input:
            # scikit-learn 1.6 renamed ``force_all_finite`` to
            # ``ensure_all_finite``; the old spelling is deprecated there and
            # scheduled for removal, so pick the keyword by installed version.
            if sklearn_check_version("1.6"):
                finite_kwargs = {"ensure_all_finite": False}
            else:
                finite_kwargs = {"force_all_finite": False}

            if sklearn_check_version("1.0"):
                X, y = validate_data(
                    self,
                    X,
                    y,
                    dtype=[np.float64, np.float32],
                    reset=first_pass,
                    copy=self.copy_X,
                    multi_output=True,
                    **finite_kwargs,
                )
            else:
                X = check_array(
                    X,
                    dtype=[np.float64, np.float32],
                    copy=self.copy_X,
                    **finite_kwargs,
                )
                y = check_array(
                    y,
                    dtype=[np.float64, np.float32],
                    copy=False,
                    ensure_2d=False,
                    **finite_kwargs,
                )

        if first_pass:
            self.n_samples_seen_ = X.shape[0]
            self.n_features_in_ = X.shape[1]
        else:
            self.n_samples_seen_ += X.shape[0]

        if not hasattr(self, "_onedal_estimator"):
            onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
            self._onedal_estimator = self._onedal_incremental_linear(**onedal_params)
        self._onedal_estimator.partial_fit(X, y, queue=queue)
        self._need_to_finalize = True

    if daal_check_version((2025, "P", 200)):

        def _onedal_validate_underdetermined(self, n_samples, n_features):
            """No-op: no sample-count check is performed for this oneDAL version."""
            pass

    else:

        def _onedal_validate_underdetermined(self, n_samples, n_features):
            """Raise ``ValueError`` when there are fewer samples than unknowns.

            The number of unknowns is ``n_features`` plus one when
            ``fit_intercept`` is set.
            """
            is_underdetermined = n_samples < n_features + int(self.fit_intercept)
            if is_underdetermined:
                raise ValueError("Not enough samples for oneDAL")

    def _onedal_finalize_fit(self, queue=None):
        """Finalize the accumulated partial results on the backend estimator."""
        assert hasattr(self, "_onedal_estimator")
        self._onedal_validate_underdetermined(self.n_samples_seen_, self.n_features_in_)
        self._onedal_estimator.finalize_fit(queue=queue)
        self._need_to_finalize = False

    def _onedal_fit(self, X, y, queue=None):
        """Fit on the full dataset by feeding it to the backend in batches.

        Resets ``n_samples_seen_`` and any existing backend estimator, then
        processes ``X``/``y`` in slices of ``batch_size_`` samples and
        finalizes the result.
        """
        if sklearn_check_version("1.2"):
            self._validate_params()

        if sklearn_check_version("1.0"):
            X, y = validate_data(
                self,
                X,
                y,
                dtype=[np.float64, np.float32],
                copy=self.copy_X,
                multi_output=True,
                ensure_2d=True,
            )
        else:
            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                copy=self.copy_X,
            )
            y = check_array(
                y,
                dtype=[np.float64, np.float32],
                copy=False,
                ensure_2d=False,
            )

        n_samples, n_features = X.shape

        self._onedal_validate_underdetermined(n_samples, n_features)

        if self.batch_size is None:
            self.batch_size_ = 5 * n_features
        else:
            self.batch_size_ = self.batch_size

        # Start from a clean state: zero the counter so the first batch is
        # treated as a first pass, and reset any previous backend estimator.
        self.n_samples_seen_ = 0
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator._reset()

        # Data was validated above, so per-batch input checks are skipped.
        for batch in gen_batches(n_samples, self.batch_size_):
            X_batch, y_batch = X[batch], y[batch]
            self._onedal_partial_fit(X_batch, y_batch, check_input=False, queue=queue)

        # finite check occurs on onedal side
        self.n_features_in_ = n_features

        if n_samples == 1:
            warnings.warn(
                "Only one sample available. You may want to reshape your data array"
            )

        self._onedal_finalize_fit(queue=queue)

        return self

    @property
    def intercept_(self):
        """Intercept of the backend estimator, finalizing pending results first."""
        if hasattr(self, "_onedal_estimator"):
            if self._need_to_finalize:
                self._onedal_finalize_fit()
            return self._onedal_estimator.intercept_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'intercept_'"
            )

    @intercept_.setter
    def intercept_(self, value):
        self.__dict__["intercept_"] = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.intercept_ = value
            # Drop the backend's cached model object after overriding the value.
            del self._onedal_estimator._onedal_model

    @property
    def coef_(self):
        """Coefficients of the backend estimator, finalizing pending results first."""
        if hasattr(self, "_onedal_estimator"):
            if self._need_to_finalize:
                self._onedal_finalize_fit()
            return self._onedal_estimator.coef_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'coef_'"
            )

    @coef_.setter
    def coef_(self, value):
        self.__dict__["coef_"] = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.coef_ = value
            # Drop the backend's cached model object after overriding the value.
            del self._onedal_estimator._onedal_model

    def partial_fit(self, X, y, check_input=True):
        """
        Incremental fit linear model with X and y.

        All of X and y is processed as a single batch.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where ``n_samples`` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values, where ``n_samples`` is the number of samples and
            ``n_targets`` is the number of targets.

        check_input : bool, default=True
            Whether to validate ``X`` and ``y`` before the batch is passed to
            the backend. When False, the inputs are used as given.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        dispatch(
            self,
            "partial_fit",
            {
                "onedal": self.__class__._onedal_partial_fit,
                "sklearn": None,
            },
            X,
            y,
            check_input=check_input,
        )
        return self

    def fit(self, X, y):
        """
        Fit the model with X and y, using minibatches of size ``batch_size``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where ``n_samples`` is the number of samples and
            ``n_features`` is the number of features. It is necessary for
            ``n_samples`` to be not less than ``n_features`` if ``fit_intercept``
            is False and not less than ``n_features + 1`` if ``fit_intercept``
            is True

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values, where ``n_samples`` is the number of samples and
            ``n_targets`` is the number of targets.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": None,
            },
            X,
            y,
        )
        return self

    @wrap_output_data
    def predict(self, X, y=None):
        """
        Predict using the linear model.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            Samples.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        C : array, shape (n_samples, n_targets)
            Returns predicted values.
        """
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": None,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        """Return the coefficient of determination of the prediction.

        The coefficient of determination :math:`R^2` is defined as
        :math:`(1 - \\frac{u}{v})`, where :math:`u` is the residual
        sum of squares ``((y_true - y_pred)** 2).sum()`` and :math:`v`
        is the total sum of squares ``((y_true - y_true.mean()) ** 2).sum()``.
        The best possible score is 1.0 and it can be negative (because the
        model can be arbitrarily worse). A constant model that always predicts
        the expected value of `y`, disregarding the input features, would get
        a :math:`R^2` score of 0.0.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples. For some estimators this may be a precomputed
            kernel matrix or a list of generic objects instead with shape
            ``(n_samples, n_samples_fitted)``, where ``n_samples_fitted``
            is the number of samples used in the fitting for the estimator.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            True values for `X`.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        score : float
            :math:`R^2` of ``self.predict(X)`` w.r.t. `y`.

        Notes
        -----
        The :math:`R^2` score used when calling ``score`` on a regressor uses
        ``multioutput='uniform_average'`` from version 0.23 to keep consistent
        with default value of :func:`~sklearn.metrics.r2_score`.
        This influences the ``score`` method of all the multioutput
        regressors (except for
        :class:`~sklearn.multioutput.MultiOutputRegressor`).
        """
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": None,
            },
            X,
            y,
            sample_weight=sample_weight,
        )