The growing complexity of machine learning algorithms limits our ability to understand what a model has learned or why it made a given prediction, which is a barrier to the adoption of machine learning in many enterprises. Machine learning explainability addresses these issues by explaining and interpreting machine learning models and their predictions, thereby improving trust in the results from an ML model.
Machine learning explainability can help machine learning developers to better understand which features are important so that they can improve the quality of the model.
It can help users of machine learning algorithms to understand why the model made a certain prediction. E.g., Why was my bank loan denied?
Development and Deployment of the Loan Approval Model
MySQL HeatWave ML supports model explanations and prediction explanations. Model explanations are created during model training and identify the features that are most important to the model overall. Prediction explanations can explain the significant features of one or more rows of input data, or of all the rows within a table. The highlights of HeatWave ML explanations are:
Quality
Performance and Scalability
Interpretability
The HeatWave ML routines referred to in this blog are described below. Please refer to the MySQL HeatWave ML documentation for details.
CALL sys.ML_TRAIN ('table_name', 'target_column_name', [options], model_handle_variable);
Running the ML_TRAIN routine on a labeled training dataset produces a trained machine learning model as well as model explanations using an efficient Permutation Importance explanation technique.
CALL sys.ML_MODEL_LOAD(model_handle, user);
The ML_MODEL_LOAD routine loads a model from the model catalog. A model remains loaded until it is unloaded using the ML_MODEL_UNLOAD routine or until HeatWave ML is restarted as part of a HeatWave Cluster restart.
SELECT sys.ML_EXPLAIN_ROW(input_data, model_handle);
The ML_EXPLAIN_ROW routine generates explanations for the predicted values for one or more rows of unlabeled data.
CALL sys.ML_EXPLAIN_TABLE(table_name, model_handle, output_table_name);
ML_EXPLAIN_TABLE explains predictions for an entire table of unlabeled data and saves results to an output table.
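ML_TRAIN is described above as producing model explanations with a permutation importance technique. The general idea can be sketched in a few lines of Python. This is a generic illustration of permutation importance, not HeatWave ML's implementation: the toy model, the random data, and the function names (accuracy, permutation_importance, toy_predict) are all assumptions made up for this example.

```python
import random

def accuracy(predict, rows, labels):
    """Fraction of rows for which predict(row) equals the label."""
    correct = sum(1 for row, label in zip(rows, labels) if predict(row) == label)
    return correct / len(rows)

def permutation_importance(predict, rows, labels, n_repeats=10, seed=0):
    """Importance of feature j = mean drop in accuracy after shuffling column j."""
    rng = random.Random(seed)
    baseline = accuracy(predict, rows, labels)
    importances = []
    for j in range(len(rows[0])):
        drops = []
        for _ in range(n_repeats):
            # Shuffle one column while leaving every other column untouched.
            column = [row[j] for row in rows]
            rng.shuffle(column)
            shuffled = [row[:j] + [value] + row[j + 1:]
                        for row, value in zip(rows, column)]
            drops.append(baseline - accuracy(predict, shuffled, labels))
        importances.append(sum(drops) / n_repeats)
    return importances

# Hypothetical toy model: predicts 1 when feature 0 exceeds 0.5, ignores feature 1.
def toy_predict(row):
    return 1 if row[0] > 0.5 else 0

data_rng = random.Random(42)
rows = [[data_rng.random(), data_rng.random()] for _ in range(200)]
labels = [toy_predict(row) for row in rows]

importances = permutation_importance(toy_predict, rows, labels)
print(importances)
```

Shuffling a feature that the model relies on breaks the link between that feature and the label, so accuracy drops. Shuffling a feature that the model ignores leaves accuracy unchanged, which is why such a feature gets an importance of 0.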
The examples of HeatWave ML explanations in this blog are based on the Census dataset from the UCI Machine Learning Repository. The prediction task is to determine whether a person makes over $50K a year based on the attributes provided in the dataset.
To try the explanation functionality described in this blog, you need to create a HeatWave cluster and use the Census dataset. To make it easy to try MySQL HeatWave ML, we have published scripts to train and evaluate models on this and other datasets at https://github.com/oracle-samples/heatwave-ml.
As explained above, a model explanation is generated when a user trains a machine learning model using the ML_TRAIN routine. The model explanation is stored in the model_explanation column of the MODEL_CATALOG table and can be accessed using the query below.
SELECT model_explanation FROM ML_SCHEMA_user1.MODEL_CATALOG WHERE model_handle=@census_model;
A model explanation helps the user identify the features that are most important to the model overall. Feature importance is presented as a numerical value ranging from 0 to 1. Higher values signify higher feature importance, lower values signify lower feature importance, and a 0 value means that the feature does not influence the model.
{"age": 0.0292, "sex": 0.0023, "race": 0.0019, "fnlwgt": 0.0038, "education": 0.0008, "workclass": 0.0068, "occupation": 0.0223, "capital-gain": 0.0479, "capital-loss": 0.0117, "relationship": 0.0234, "education-num": 0.0352, "hours-per-week": 0.0148, "marital-status": 0.024, "native-country": 0.0}
This model explanation shows that capital-gain and education-num are the most significant features of this model. Features such as native-country and education are not significant.
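Reading the ranking off the raw JSON by eye is error-prone, so it can help to sort it. The following is a minimal Python sketch that assumes the model_explanation value has already been fetched as a string. The JSON is copied from the output above, and the variable names are illustrative.

```python
import json

# Model explanation copied from the query output shown above.
model_explanation = json.loads('''
{"age": 0.0292, "sex": 0.0023, "race": 0.0019, "fnlwgt": 0.0038,
 "education": 0.0008, "workclass": 0.0068, "occupation": 0.0223,
 "capital-gain": 0.0479, "capital-loss": 0.0117, "relationship": 0.0234,
 "education-num": 0.0352, "hours-per-week": 0.0148,
 "marital-status": 0.024, "native-country": 0.0}
''')

# Sort features from most to least important.
ranked = sorted(model_explanation.items(), key=lambda item: item[1], reverse=True)
for feature, importance in ranked:
    print(f"{feature:<16}{importance:.4f}")

# Features with an importance of 0 do not influence the model.
unused = [feature for feature, importance in ranked if importance == 0]
print("No influence:", unused)
```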
Explain prediction on individual row
The sys.ML_EXPLAIN_ROW routine creates in-line explanations for one or more rows of input data. The explanations are conveyed through weights that show how much each attribute influenced the final prediction; higher values mean a stronger influence.
mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
Query OK, 0 rows affected (1.12 sec)
mysql> SELECT sys.ML_EXPLAIN_ROW('{"index": 1, "age": 38, "workclass": "Private", "fnlwgt": 89814, "education": "HS-grad", "education-num": 9, "marital-status": "Married-civ-spouse", "occupation": "Farming-fishing", "relationship": "Husband", "race": "White", "sex": "Male", "capital-gain": 0, "capital-loss": 0, "hours-per-week": 50, "native-country": "United-States"}', @model);
{"age": 38, "sex": "Male", "race": "White", "index": 1, "fnlwgt": 89814, "education": "HS-grad", "workclass": "Private", "Prediction": "<=50K", "occupation": "Farming-fishing", "capital-gain": 0, "capital-loss": 0, "relationship": "Husband", "education-num": 9, "hours-per-week": 50, "marital-status": "Married-civ-spouse", "native-country": "United-States", "age_attribution": 0.2234, "sex_attribution": 0.0241, "race_attribution": 0.0011, "index_attribution": 0.0, "fnlwgt_attribution": 0.003, "education_attribution": 0.0, "workclass_attribution": 0.0126, "occupation_attribution": 0.1111, "capital-gain_attribution": 0.0, "capital-loss_attribution": 0.0, "relationship_attribution": 0.0928, "education-num_attribution": 0.1305, "hours-per-week_attribution": 0.1806, "marital-status_attribution": 0.0676, "native-country_attribution": 0.0001}
1 row in set (4.41 sec)
This prediction explanation shows that the model predicts an income of <=50K for this individual. The most significant features for this prediction are age and hours-per-week. Features such as capital-gain, capital-loss, and education are not significant.
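The ML_EXPLAIN_ROW output mixes the input values, the prediction, and the per-feature weights in one JSON object. As a rough sketch, the weights can be separated out and ranked in Python by picking the keys that end in "_attribution", the suffix used in the output above. The attribution values are copied from that output. The helper name rank_attributions is hypothetical.

```python
import json

SUFFIX = "_attribution"

def rank_attributions(explanation):
    """Return (feature, weight) pairs sorted from strongest to weakest influence."""
    weights = {
        key[: -len(SUFFIX)]: value
        for key, value in explanation.items()
        if key.endswith(SUFFIX)
    }
    return sorted(weights.items(), key=lambda item: item[1], reverse=True)

# Prediction and attribution fields copied from the ML_EXPLAIN_ROW output above.
row_explanation = json.loads('''
{"Prediction": "<=50K",
 "age_attribution": 0.2234, "sex_attribution": 0.0241,
 "race_attribution": 0.0011, "index_attribution": 0.0,
 "fnlwgt_attribution": 0.003, "education_attribution": 0.0,
 "workclass_attribution": 0.0126, "occupation_attribution": 0.1111,
 "capital-gain_attribution": 0.0, "capital-loss_attribution": 0.0,
 "relationship_attribution": 0.0928, "education-num_attribution": 0.1305,
 "hours-per-week_attribution": 0.1806, "marital-status_attribution": 0.0676,
 "native-country_attribution": 0.0001}
''')

ranked = rank_attributions(row_explanation)
top3 = ranked[:3]
print("Prediction:", row_explanation["Prediction"])
for feature, weight in top3:
    print(f"{feature:<16}{weight:.4f}")
```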
Explain prediction on a table
The sys.ML_EXPLAIN_TABLE routine creates and populates a new table with features, predictions, and explanations for each row of the input table. Explanations across rows are computed in parallel. The input table must contain the same columns as the table used to train the loaded model, except for the target column.
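The column requirement can be checked before calling the routine. The sketch below is an illustrative client-side check in Python, not part of HeatWave ML. The function name check_input_columns, the shortened column lists, and the target column name "revenue" are assumptions chosen for the example.

```python
def check_input_columns(training_columns, target_column, input_columns):
    """Compare the input columns with the training columns minus the target.

    Returns (missing, unexpected) as sorted lists of column names.
    """
    expected = set(training_columns) - {target_column}
    actual = set(input_columns)
    return sorted(expected - actual), sorted(actual - expected)

# Hypothetical, shortened column lists for illustration only.
training_columns = ["age", "workclass", "education", "hours-per-week", "revenue"]
input_columns = ["age", "workclass", "education", "revenue"]

missing, unexpected = check_input_columns(training_columns, "revenue", input_columns)
print("Missing from input table:", missing)
print("Should not be in input table:", unexpected)
```

Both lists are empty when the input table is suitable for the model.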
mysql> CALL sys.ML_MODEL_LOAD(@model, NULL);
Query OK, 0 rows affected (1.12 sec)
mysql> CALL sys.ML_EXPLAIN_TABLE('mlcorpus_v4.census_test_naive', @model, 'mlcorpus_v4.census_explanations');
Query OK, 0 rows affected (12.95 sec)
The ML_EXPLAIN_TABLE routine created the census_explanations table, in which each row shows the earnings prediction and the significance of each feature for that prediction.
In summary, model explanations identify the features that are most important to the model overall. Prediction explanations can explain one or more rows of input data or all the rows of a table. MySQL HeatWave ML explanations can help machine learning developers better understand which features are important so that they can improve the quality of the model, and can help users of machine learning algorithms understand why the model made a certain prediction.
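Once the explanations are in a table, they can be summarized across rows. The following Python sketch averages the per-feature weights over rows that have already been fetched as dictionaries. It assumes the output table uses the same "_attribution" suffix as the ML_EXPLAIN_ROW output shown earlier, which this blog does not confirm. The sample rows and the helper name mean_attributions are made up for illustration.

```python
from collections import defaultdict

SUFFIX = "_attribution"

def mean_attributions(rows):
    """Average each '<feature>_attribution' column over all rows."""
    totals = defaultdict(float)
    for row in rows:
        for column, value in row.items():
            if column.endswith(SUFFIX):
                totals[column[: -len(SUFFIX)]] += value
    return {feature: total / len(rows) for feature, total in totals.items()}

# Hypothetical rows, as a database client might return them (assumed layout).
rows = [
    {"Prediction": "<=50K", "age_attribution": 0.25, "hours-per-week_attribution": 0.125},
    {"Prediction": ">50K", "age_attribution": 0.75, "hours-per-week_attribution": 0.375},
]

means = mean_attributions(rows)
for feature, mean in sorted(means.items(), key=lambda item: item[1], reverse=True):
    print(f"{feature:<16}{mean:.4f}")
```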
Salil Pradhan is a Product Manager on the MySQL HeatWave team. His interests include distributed data processing, machine learning, cloud computing, and middleware technologies, as well as application areas such as Marketing Automation and Supply Chain Management.