{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intel® Extension for Scikit-learn RandomForestClassifier for the Rain in Australia dataset\n",
"The goal is to predict whether it will rain the next day."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from timeit import default_timer as timer\n",
"from IPython.display import HTML\n",
"import warnings\n",
"\n",
"from sklearn.datasets import fetch_openml\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download the data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Date | \n",
" Location | \n",
" MinTemp | \n",
" MaxTemp | \n",
" Rainfall | \n",
" Evaporation | \n",
" Sunshine | \n",
" WindGustDir | \n",
" WindGustSpeed | \n",
" WindDir9am | \n",
" ... | \n",
" Humidity9am | \n",
" Humidity3pm | \n",
" Pressure9am | \n",
" Pressure3pm | \n",
" Cloud9am | \n",
" Cloud3pm | \n",
" Temp9am | \n",
" Temp3pm | \n",
" RainToday | \n",
" RainTomorrow | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2008-12-01 | \n",
" Albury | \n",
" 13.4 | \n",
" 22.9 | \n",
" 0.6 | \n",
" NaN | \n",
" NaN | \n",
" W | \n",
" 44.0 | \n",
" W | \n",
" ... | \n",
" 71.0 | \n",
" 22.0 | \n",
" 1007.7 | \n",
" 1007.1 | \n",
" 8.0 | \n",
" NaN | \n",
" 16.9 | \n",
" 21.8 | \n",
" No | \n",
" No | \n",
"
\n",
" \n",
" 1 | \n",
" 2008-12-02 | \n",
" Albury | \n",
" 7.4 | \n",
" 25.1 | \n",
" 0.0 | \n",
" NaN | \n",
" NaN | \n",
" WNW | \n",
" 44.0 | \n",
" NNW | \n",
" ... | \n",
" 44.0 | \n",
" 25.0 | \n",
" 1010.6 | \n",
" 1007.8 | \n",
" NaN | \n",
" NaN | \n",
" 17.2 | \n",
" 24.3 | \n",
" No | \n",
" No | \n",
"
\n",
" \n",
" 2 | \n",
" 2008-12-03 | \n",
" Albury | \n",
" 12.9 | \n",
" 25.7 | \n",
" 0.0 | \n",
" NaN | \n",
" NaN | \n",
" WSW | \n",
" 46.0 | \n",
" W | \n",
" ... | \n",
" 38.0 | \n",
" 30.0 | \n",
" 1007.6 | \n",
" 1008.7 | \n",
" NaN | \n",
" 2.0 | \n",
" 21.0 | \n",
" 23.2 | \n",
" No | \n",
" No | \n",
"
\n",
" \n",
" 3 | \n",
" 2008-12-04 | \n",
" Albury | \n",
" 9.2 | \n",
" 28.0 | \n",
" 0.0 | \n",
" NaN | \n",
" NaN | \n",
" NE | \n",
" 24.0 | \n",
" SE | \n",
" ... | \n",
" 45.0 | \n",
" 16.0 | \n",
" 1017.6 | \n",
" 1012.8 | \n",
" NaN | \n",
" NaN | \n",
" 18.1 | \n",
" 26.5 | \n",
" No | \n",
" No | \n",
"
\n",
" \n",
" 4 | \n",
" 2008-12-05 | \n",
" Albury | \n",
" 17.5 | \n",
" 32.3 | \n",
" 1.0 | \n",
" NaN | \n",
" NaN | \n",
" W | \n",
" 41.0 | \n",
" ENE | \n",
" ... | \n",
" 82.0 | \n",
" 33.0 | \n",
" 1010.8 | \n",
" 1006.0 | \n",
" 7.0 | \n",
" 8.0 | \n",
" 17.8 | \n",
" 29.7 | \n",
" No | \n",
" No | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 23 columns
\n",
"
"
],
"text/plain": [
" Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine \\\n",
"0 2008-12-01 Albury 13.4 22.9 0.6 NaN NaN \n",
"1 2008-12-02 Albury 7.4 25.1 0.0 NaN NaN \n",
"2 2008-12-03 Albury 12.9 25.7 0.0 NaN NaN \n",
"3 2008-12-04 Albury 9.2 28.0 0.0 NaN NaN \n",
"4 2008-12-05 Albury 17.5 32.3 1.0 NaN NaN \n",
"\n",
" WindGustDir WindGustSpeed WindDir9am ... Humidity9am Humidity3pm \\\n",
"0 W 44.0 W ... 71.0 22.0 \n",
"1 WNW 44.0 NNW ... 44.0 25.0 \n",
"2 WSW 46.0 W ... 38.0 30.0 \n",
"3 NE 24.0 SE ... 45.0 16.0 \n",
"4 W 41.0 ENE ... 82.0 33.0 \n",
"\n",
" Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday \\\n",
"0 1007.7 1007.1 8.0 NaN 16.9 21.8 No \n",
"1 1010.6 1007.8 NaN NaN 17.2 24.3 No \n",
"2 1007.6 1008.7 NaN 2.0 21.0 23.2 No \n",
"3 1017.6 1012.8 NaN NaN 18.1 26.5 No \n",
"4 1010.8 1006.0 7.0 8.0 17.8 29.7 No \n",
"\n",
" RainTomorrow \n",
"0 No \n",
"1 No \n",
"2 No \n",
"3 No \n",
"4 No \n",
"\n",
"[5 rows x 23 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = fetch_openml(data_id=46315, as_frame=True)\n",
"df = data.frame\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explore the data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(145460, 23)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the dimensions of the dataset\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 145460 entries, 0 to 145459\n",
"Data columns (total 23 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Date 145460 non-null object \n",
" 1 Location 145460 non-null object \n",
" 2 MinTemp 143975 non-null float64\n",
" 3 MaxTemp 144199 non-null float64\n",
" 4 Rainfall 142199 non-null float64\n",
" 5 Evaporation 82670 non-null float64\n",
" 6 Sunshine 75625 non-null float64\n",
" 7 WindGustDir 135134 non-null object \n",
" 8 WindGustSpeed 135197 non-null float64\n",
" 9 WindDir9am 134894 non-null object \n",
" 10 WindDir3pm 141232 non-null object \n",
" 11 WindSpeed9am 143693 non-null float64\n",
" 12 WindSpeed3pm 142398 non-null float64\n",
" 13 Humidity9am 142806 non-null float64\n",
" 14 Humidity3pm 140953 non-null float64\n",
" 15 Pressure9am 130395 non-null float64\n",
" 16 Pressure3pm 130432 non-null float64\n",
" 17 Cloud9am 89572 non-null float64\n",
" 18 Cloud3pm 86102 non-null float64\n",
" 19 Temp9am 143693 non-null float64\n",
" 20 Temp3pm 141851 non-null float64\n",
" 21 RainToday 142199 non-null object \n",
" 22 RainTomorrow 142193 non-null object \n",
"dtypes: float64(16), object(7)\n",
"memory usage: 25.5+ MB\n"
]
}
],
"source": [
"# Show the summary of the dataset\n",
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Missing Values | \n",
" Percentage (%) | \n",
"
\n",
" \n",
" \n",
" \n",
" Date | \n",
" 0 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" Location | \n",
" 0 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" MinTemp | \n",
" 1485 | \n",
" 1.020899 | \n",
"
\n",
" \n",
" MaxTemp | \n",
" 1261 | \n",
" 0.866905 | \n",
"
\n",
" \n",
" Rainfall | \n",
" 3261 | \n",
" 2.241853 | \n",
"
\n",
" \n",
" Evaporation | \n",
" 62790 | \n",
" 43.166506 | \n",
"
\n",
" \n",
" Sunshine | \n",
" 69835 | \n",
" 48.009762 | \n",
"
\n",
" \n",
" WindGustDir | \n",
" 10326 | \n",
" 7.098859 | \n",
"
\n",
" \n",
" WindGustSpeed | \n",
" 10263 | \n",
" 7.055548 | \n",
"
\n",
" \n",
" WindDir9am | \n",
" 10566 | \n",
" 7.263853 | \n",
"
\n",
" \n",
" WindDir3pm | \n",
" 4228 | \n",
" 2.906641 | \n",
"
\n",
" \n",
" WindSpeed9am | \n",
" 1767 | \n",
" 1.214767 | \n",
"
\n",
" \n",
" WindSpeed3pm | \n",
" 3062 | \n",
" 2.105046 | \n",
"
\n",
" \n",
" Humidity9am | \n",
" 2654 | \n",
" 1.824557 | \n",
"
\n",
" \n",
" Humidity3pm | \n",
" 4507 | \n",
" 3.098446 | \n",
"
\n",
" \n",
" Pressure9am | \n",
" 15065 | \n",
" 10.356799 | \n",
"
\n",
" \n",
" Pressure3pm | \n",
" 15028 | \n",
" 10.331363 | \n",
"
\n",
" \n",
" Cloud9am | \n",
" 55888 | \n",
" 38.421559 | \n",
"
\n",
" \n",
" Cloud3pm | \n",
" 59358 | \n",
" 40.807095 | \n",
"
\n",
" \n",
" Temp9am | \n",
" 1767 | \n",
" 1.214767 | \n",
"
\n",
" \n",
" Temp3pm | \n",
" 3609 | \n",
" 2.481094 | \n",
"
\n",
" \n",
" RainToday | \n",
" 3261 | \n",
" 2.241853 | \n",
"
\n",
" \n",
" RainTomorrow | \n",
" 3267 | \n",
" 2.245978 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Missing Values Percentage (%)\n",
"Date 0 0.000000\n",
"Location 0 0.000000\n",
"MinTemp 1485 1.020899\n",
"MaxTemp 1261 0.866905\n",
"Rainfall 3261 2.241853\n",
"Evaporation 62790 43.166506\n",
"Sunshine 69835 48.009762\n",
"WindGustDir 10326 7.098859\n",
"WindGustSpeed 10263 7.055548\n",
"WindDir9am 10566 7.263853\n",
"WindDir3pm 4228 2.906641\n",
"WindSpeed9am 1767 1.214767\n",
"WindSpeed3pm 3062 2.105046\n",
"Humidity9am 2654 1.824557\n",
"Humidity3pm 4507 3.098446\n",
"Pressure9am 15065 10.356799\n",
"Pressure3pm 15028 10.331363\n",
"Cloud9am 55888 38.421559\n",
"Cloud3pm 59358 40.807095\n",
"Temp9am 1767 1.214767\n",
"Temp3pm 3609 2.481094\n",
"RainToday 3261 2.241853\n",
"RainTomorrow 3267 2.245978"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Count the missing entries per column and express them as a percentage of all rows\n",
"n_rows = len(df)\n",
"missing_values = df.isna().sum()\n",
"missing_values_percentage = missing_values / n_rows * 100\n",
"\n",
"# One row per column of df: absolute count next to the relative share\n",
"missing_values_df = pd.DataFrame(\n",
"    {\n",
"        'Missing Values': missing_values,\n",
"        'Percentage (%)': missing_values_percentage,\n",
"    }\n",
")\n",
"missing_values_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(145460, 19)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop columns with more than 30% missing values\n",
"df = df.dropna(thresh=df.shape[0]*0.7, axis=1)\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(142193, 19)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop rows with missing target value\n",
"df = df.dropna(subset=['RainTomorrow'])\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Encode the target variable as integers (No -> 0, Yes -> 1).\n",
"# The dtype guard keeps this cell idempotent: mapping already-encoded 0/1 values\n",
"# through the dict again would turn every label into NaN.\n",
"# .assign() returns a new frame instead of setting a column on the frame that\n",
"# came out of the earlier dropna() call.\n",
"if not pd.api.types.is_integer_dtype(df['RainTomorrow']):\n",
"    df = df.assign(RainTomorrow=df['RainTomorrow'].map({'No': 0, 'Yes': 1}))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Index: 142193 entries, 0 to 145458\n",
"Data columns (total 22 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Date 142193 non-null datetime64[ns]\n",
" 1 Location 142193 non-null object \n",
" 2 MinTemp 141556 non-null float64 \n",
" 3 MaxTemp 141871 non-null float64 \n",
" 4 Rainfall 140787 non-null float64 \n",
" 5 WindGustDir 132863 non-null object \n",
" 6 WindGustSpeed 132923 non-null float64 \n",
" 7 WindDir9am 132180 non-null object \n",
" 8 WindDir3pm 138415 non-null object \n",
" 9 WindSpeed9am 140845 non-null float64 \n",
" 10 WindSpeed3pm 139563 non-null float64 \n",
" 11 Humidity9am 140419 non-null float64 \n",
" 12 Humidity3pm 138583 non-null float64 \n",
" 13 Pressure9am 128179 non-null float64 \n",
" 14 Pressure3pm 128212 non-null float64 \n",
" 15 Temp9am 141289 non-null float64 \n",
" 16 Temp3pm 139467 non-null float64 \n",
" 17 RainToday 140787 non-null object \n",
" 18 RainTomorrow 142193 non-null int64 \n",
" 19 Year 142193 non-null int64 \n",
" 20 Month 142193 non-null int64 \n",
" 21 Day 142193 non-null int64 \n",
"dtypes: datetime64[ns](1), float64(12), int64(4), object(5)\n",
"memory usage: 25.0+ MB\n"
]
}
],
"source": [
"# Parse Date once, then derive Year, Month, and Day from the parsed values.\n",
"# assign() replaces Date in place and appends the three new columns in the order given.\n",
"dates = pd.to_datetime(df['Date'])\n",
"df = df.assign(\n",
"    Date=dates,\n",
"    Year=dates.dt.year.astype('int64'),\n",
"    Month=dates.dt.month.astype('int64'),\n",
"    Day=dates.dt.day.astype('int64'),\n",
")\n",
"\n",
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Separate the target from the features; the raw Date column is dropped because\n",
"# Year, Month, and Day were already derived from it\n",
"y = df['RainTomorrow']\n",
"X = df.drop(columns=['Date', 'RainTomorrow'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numerical Columns: ['MinTemp', 'MaxTemp', 'Rainfall', 'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm', 'Pressure9am', 'Pressure3pm', 'Temp9am', 'Temp3pm', 'Year', 'Month', 'Day']\n",
"Categorical Columns: ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']\n"
]
}
],
"source": [
"# Group the feature columns by dtype: int64/float64 are numerical, object is categorical\n",
"numeric_dtypes = ['int64', 'float64']\n",
"num_columns = X.select_dtypes(include=numeric_dtypes).columns\n",
"cat_columns = X.select_dtypes(include=['object']).columns\n",
"\n",
"print(f'Numerical Columns: {num_columns.tolist()}')\n",
"print(f'Categorical Columns: {cat_columns.tolist()}')\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Preprocess the numerical features\n",
"\n",
"# Impute missing values with the mean\n",
"imputer_num = SimpleImputer(strategy='mean')\n",
"X[num_columns] = imputer_num.fit_transform(X[num_columns])\n",
"\n",
"# Scale the numerical columns\n",
"scaler = StandardScaler()\n",
"X[num_columns] = scaler.fit_transform(X[num_columns])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Preprocess the categorical features\n",
"\n",
"# Impute missing values with the most frequent value\n",
"imputer_cat = SimpleImputer(strategy='most_frequent')\n",
"X[cat_columns] = imputer_cat.fit_transform(X[cat_columns])\n",
"\n",
"# Label encode the categorical columns\n",
"encoder = LabelEncoder()\n",
"for col in cat_columns:\n",
" X[col] = encoder.fit_transform(X[col])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Index: 142193 entries, 0 to 145458\n",
"Data columns (total 22 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Date 142193 non-null datetime64[ns]\n",
" 1 Location 142193 non-null object \n",
" 2 MinTemp 141556 non-null float64 \n",
" 3 MaxTemp 141871 non-null float64 \n",
" 4 Rainfall 140787 non-null float64 \n",
" 5 WindGustDir 132863 non-null object \n",
" 6 WindGustSpeed 132923 non-null float64 \n",
" 7 WindDir9am 132180 non-null object \n",
" 8 WindDir3pm 138415 non-null object \n",
" 9 WindSpeed9am 140845 non-null float64 \n",
" 10 WindSpeed3pm 139563 non-null float64 \n",
" 11 Humidity9am 140419 non-null float64 \n",
" 12 Humidity3pm 138583 non-null float64 \n",
" 13 Pressure9am 128179 non-null float64 \n",
" 14 Pressure3pm 128212 non-null float64 \n",
" 15 Temp9am 141289 non-null float64 \n",
" 16 Temp3pm 139467 non-null float64 \n",
" 17 RainToday 140787 non-null object \n",
" 18 RainTomorrow 142193 non-null int64 \n",
" 19 Year 142193 non-null int64 \n",
" 20 Month 142193 non-null int64 \n",
" 21 Day 142193 non-null int64 \n",
"dtypes: datetime64[ns](1), float64(12), int64(4), object(5)\n",
"memory usage: 25.0+ MB\n"
]
}
],
"source": [
"# Ensure all feature columns are numerical and have no missing values.\n",
"# Inspect X, the preprocessed feature frame; df still holds the raw columns,\n",
"# which were never imputed or encoded.\n",
"X.info()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((113754, 20), (28439, 20), (113754,), (28439,))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Split the dataset into training and testing sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"X_train.shape, X_test.shape, y_train.shape, y_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Patch original Scikit-learn with Intel® Extension for Scikit-learn\n",
"Intel® Extension for Scikit-learn (previously known as daal4py) contains drop-in replacement functionality for the stock Scikit-learn package. You can take advantage of the performance optimizations of Intel® Extension for Scikit-learn by adding just two lines of code before the usual Scikit-learn imports:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Intel(R) Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)\n"
]
}
],
"source": [
"from sklearnex import patch_sklearn\n",
"patch_sklearn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training of the RandomForestClassifier with Intel® Extension for Scikit-learn for the Rain in Australia dataset"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Intel® extension for Scikit-learn Training Time: 9.439 seconds\n"
]
}
],
"source": [
"# Imported after patch_sklearn() so the patched implementation is used\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# Hyperparameters shared by the patched and the original run.\n",
"# random_state is fixed so that repeated runs of the notebook are comparable;\n",
"# without it every run grows a different random forest.\n",
"params = {\n",
"    'n_estimators': 1000,\n",
"    'criterion': 'gini',\n",
"    'max_features': 'sqrt',\n",
"    'n_jobs': -1,\n",
"    'random_state': 42,\n",
"}\n",
"\n",
"# Time only the construction and fit of the model\n",
"start = timer()\n",
"patched_model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
"patched_train_time = timer() - start\n",
"\n",
"print(f\"Intel® extension for Scikit-learn Training Time: {patched_train_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict on the test set and compute the accuracy of the RandomForestClassifier trained with Intel® Extension for Scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Intel® extension for Scikit-learn Accuracy: 0.8551\n"
]
}
],
"source": [
"patched_y_pred = patched_model.predict(X_test)\n",
"patched_accuracy = accuracy_score(y_test, patched_y_pred)\n",
"\n",
"print(f\"Intel® extension for Scikit-learn Accuracy: {patched_accuracy:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train the same algorithm with original Scikit-learn\n",
"\n",
"In order to cancel optimizations, we use *unpatch_sklearn* and reimport the class RandomForestClassifier."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from sklearnex import unpatch_sklearn\n",
"unpatch_sklearn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training of the RandomForestClassifier with original Scikit-learn for the Rain in Australia dataset"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Scikit-learn Training Time: 47.955 seconds\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"start = timer()\n",
"ori_model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
"ori_train_time = timer() - start\n",
"\n",
"print(f\"Original Scikit-learn Training Time: {ori_train_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict on the test set and compute the accuracy of the RandomForestClassifier trained with original Scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Scikit-learn Accuracy: 0.8549\n"
]
}
],
"source": [
"ori_y_pred = ori_model.predict(X_test)\n",
"ori_accuracy = accuracy_score(y_test, ori_y_pred)\n",
"\n",
"print(f\"Original Scikit-learn Accuracy: {ori_accuracy:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparison\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Original | \n",
" Patched | \n",
" Improvement (%) | \n",
"
\n",
" \n",
" \n",
" \n",
" Accuracy | \n",
" 0.8549 | \n",
" 0.8551 | \n",
" 0.02 | \n",
"
\n",
" \n",
" Training Time (s) | \n",
" 47.9546 | \n",
" 9.4395 | \n",
" -80.32 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Original Patched Improvement (%)\n",
"Accuracy 0.8549 0.8551 0.02\n",
"Training Time (s) 47.9546 9.4395 -80.32"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Side-by-side metrics of both runs, rounded to four decimals\n",
"compare_df = pd.DataFrame(\n",
"    {\n",
"        'Original': [ori_accuracy, ori_train_time],\n",
"        'Patched': [patched_accuracy, patched_train_time],\n",
"    },\n",
"    index=['Accuracy', 'Training Time (s)'],\n",
").round(4)\n",
"\n",
"# Relative change of the patched run versus the original, in percent.\n",
"# Computed as (Patched - Original) / Original, so a shorter training time\n",
"# shows up as a negative value.\n",
"relative_change = (compare_df['Patched'] - compare_df['Original']) / compare_df['Original']\n",
"compare_df['Improvement (%)'] = (relative_change * 100).round(2)\n",
"\n",
"compare_df"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Compare Accuracy of patched Scikit-learn and original
Accuracy of patched Scikit-learn: 0.8550933577129998
Accuracy of unpatched Scikit-learn: 0.8548823798305144
Metrics ratio: 1.0002467917077986
With Scikit-learn-intelex patching you can:
- Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);
- Get comparable model quality
- Get a 5.1x speedup.
"
],
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Render the accuracy and speed comparison as an HTML summary.\n",
"# Adjacent f-strings are concatenated into a single markup string.\n",
"HTML(\n",
"    f\"<h3>Compare Accuracy of patched Scikit-learn and original</h3>\"\n",
"    f\"Accuracy of patched Scikit-learn: {patched_accuracy}<br>\"\n",
"    f\"Accuracy of unpatched Scikit-learn: {ori_accuracy}<br>\"\n",
"    f\"Metrics ratio: {patched_accuracy/ori_accuracy}<br>\"\n",
"    f\"<h3>With Scikit-learn-intelex patching you can:</h3>\"\n",
"    f\"<ul>\"\n",
"    f\"<li>Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);</li>\"\n",
"    f\"<li>Get comparable model quality</li>\"\n",
"    f\"<li>Get a {(ori_train_time/patched_train_time):.1f}x speedup.</li>\"\n",
"    f\"</ul>\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}