Monday May 19, 2014

Model cross-validation with ore.CV()

In this blog post we illustrate how to use Oracle R Enterprise for performing cross-validation of regression and classification models. We describe a new utility R function ore.CV that leverages features of Oracle R Enterprise and is available for download and use.

Predictive models are usually built on given data and verified on held-aside or unseen data. Cross-validation is a model improvement technique that avoids the limitations of a single train-and-test experiment by building and testing multiple models via repeated sampling from the available data. Its purpose is to offer better insight into how well the model would generalize to new data, and to avoid over-fitting and drawing wrong conclusions from misleading peculiarities of the seen data.

In a k-fold cross-validation the data is partitioned into k (roughly) equal size subsets. One of the subsets is retained for testing and the remaining k-1 subsets are used for training. The process is repeated k times with each of the k subsets serving exactly once as testing data. Thus, all observations in the original data set are used for both training and testing.
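The partitioning described above can be sketched in a few lines of base R. This is only an in-memory illustration on the built-in iris data; make_folds is a hypothetical helper name and is not part of ORE or ore.CV.

```r
# Assign each of n observations to one of k folds of (roughly) equal size.
# make_folds is a hypothetical helper used for illustration only.
make_folds <- function(n, k, seed = 1) {
  set.seed(seed)
  sample(rep(seq_len(k), length.out = n))  # shuffled fold labels 1..k
}

k <- 5
folds <- make_folds(nrow(iris), k)
table(folds)                      # fold sizes; 30 each for n = 150 and k = 5

for (i in seq_len(k)) {
  test_idx  <- which(folds == i)  # fold i is held out for testing
  train_idx <- which(folds != i)  # the remaining k-1 folds are used for training
  # every observation is in exactly one of the two sets
  stopifnot(length(intersect(test_idx, train_idx)) == 0)
}
```

Because every observation carries exactly one fold label, each one is used for testing exactly once and for training k-1 times.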

The choice of k depends, in practice, on the size n of the data set. For large data, k=3 could be sufficient. For very small data, the extreme case k=n, known as leave-one-out cross-validation (LOOCV), uses a single observation from the original sample as testing data and the remaining observations as training data. Common choices are k=10 or k=5.
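As a small illustration of the k=n case, the following base R sketch runs leave-one-out cross-validation for a linear model on the built-in longley data (16 observations). It runs entirely in the R client and does not use ORE; the formula is borrowed from the examples in this post.

```r
# Leave-one-out CV (k = n) for a linear model on the small longley data set.
n    <- nrow(longley)            # 16 observations
pred <- numeric(n)

for (i in seq_len(n)) {
  # train on all observations except the i-th one
  fit <- lm(Employed ~ GNP + Population + Year, data = longley[-i, ])
  # test on the single held-out observation
  pred[i] <- predict(fit, newdata = longley[i, , drop = FALSE])
}

loocv_rmse <- sqrt(mean((longley$Employed - pred)^2))
loocv_rmse
```

With n models to fit, this approach is only practical when n is small or the model is cheap to train.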

For a select set of algorithms and cases, the function ore.CV performs cross-validation for models generated by ORE regression and classification functions using in-database data. ORE embedded R execution is leveraged to also support cross-validation for models built with vanilla R functions.

Usage

ore.CV(funType, function, formula, dataset, nFolds=<nb.folds>, fun.args=NULL, pred.args=NULL, pckg.lst=NULL)
  • funType - "regression" or "classification"
  • function - ORE predictive modeling functions for regression & classification or R function (regression only)
  • formula - object of class "formula"
  • dataset - name of the ore.frame
  • nFolds - number of folds
  • fun.args - list of supplementary arguments for 'function'
  • pred.args - list of supplementary arguments for 'predict'. Must be consistent with the model object/model generator 'function'.
  • pckg.lst - list of packages to be loaded by the DB R engine for embedded execution.
The set of functions supported for ORE include:
  • ore.lm
  • ore.stepwise
  • ore.neural
  • ore.glm
  • ore.odmDT
  • ore.odmSVM
  • ore.odmGLM
  • ore.odmNB
The set of functions supported for R include:
  • lm
  • glm
  • svm
Note: The 'ggplot' and 'reshape' packages are required on the R client side for data post-processing and plotting (classification CV).
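To make the roles of fun.args and pred.args concrete, here is a minimal in-memory sketch of a cross-validation loop with a similar calling convention. cv_sketch is a hypothetical stand-in written for this illustration; it is not the ore.CV implementation, it works on a plain data.frame rather than an ore.frame, and it assumes the extra arguments are simply appended to the model and predict calls.

```r
# cv_sketch: hypothetical, in-memory illustration only (NOT the ore.CV code).
cv_sketch <- function(fun, formula, data, nFolds = 5,
                      fun.args = NULL, pred.args = NULL, seed = 1) {
  set.seed(seed)
  folds  <- sample(rep(seq_len(nFolds), length.out = nrow(data)))
  target <- all.vars(formula)[1]          # left-hand side of the formula
  out <- vector("list", nFolds)
  for (i in seq_len(nFolds)) {
    train <- data[folds != i, , drop = FALSE]
    test  <- data[folds == i, , drop = FALSE]
    # fun.args are appended to the mandatory (formula, data) arguments
    model <- do.call(fun, c(list(formula = formula, data = train), fun.args))
    # pred.args are appended to the predict() call
    pred  <- do.call(predict, c(list(model, newdata = test), pred.args))
    out[[i]] <- data.frame(fold = i,
                           target = test[[target]],
                           prediction = as.vector(pred))
  }
  do.call(rbind, out)                     # one prediction table for all folds
}

res <- cv_sketch("lm", mpg ~ cyl + disp + hp + drat + wt + qsec, mtcars, nFolds = 3)
head(res)
```

Passing the model generator by name and the extra arguments as lists keeps the loop independent of any particular algorithm, which is presumably why ore.CV requires generators of the form function(formula, data, ...).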

Examples

In the following examples, we illustrate various ways to invoke ore.CV using some datasets we have seen in previous posts. The datasets can be created as ore.frame objects using:
 
IRIS <- ore.push(iris)
LONGLEY <- ore.push(longley)
library(rpart)
KYPHOSIS <- ore.push(kyphosis)
library(PASWR)
TITANIC3 <- ore.push(titanic3)
MTCARS <- ore.push(mtcars)
(A) Cross-validation for models generated with ORE functions.
 
# Basic specification
ore.CV("regression","ore.lm",Sepal.Length~.-Species,"IRIS",nFolds=5)
ore.CV("regression","ore.neural",Employed~GNP+Population+Year,
            "LONGLEY",nFolds=5)

#Specification of function arguments
ore.CV("regression","ore.stepwise",Employed~.,"LONGLEY",nFolds=5,
            fun.args= list(add.p=0.15,drop.p=0.15))
ore.CV("regression","ore.odmSVM",Employed~GNP+Population+Year,
             "LONGLEY",nFolds=5, fun.args="regression")

#Specification of function arguments and prediction arguments
ore.CV("classification","ore.glm",Kyphosis~.,"KYPHOSIS",nFolds=5,
             fun.args=list(family=binomial()),pred.args=list(type="response"))
ore.CV("classification","ore.odmGLM",Kyphosis~.,"KYPHOSIS",nFolds=5,
            fun.args= list(type="logistic"),pred.args=list(type="response"))
 
(B) Cross-validation for models generated with R functions via the ORE embedded execution mechanism.

ore.CV("regression","lm",mpg~cyl+disp+hp+drat+wt+qsec,"MTCARS",nFolds=3)
ore.CV("regression","svm",Sepal.Length~.-Species,"IRIS",nFolds=5,
             fun.args=list(type="eps-regression"), pckg.lst=c("e1071")) 


Restrictions

  • The signature of the model generator ‘function’ must be of the following type: function(formula,data,...). For example, functions like ore.stepwise, ore.odmGLM and lm are supported, but the R step(object,scope,...) function for AIC model selection via the stepwise algorithm does not satisfy this requirement.
  • The model validation process requires the prediction function to return a (1-dimensional) vector with the predicted values. If the (default) returned object is different, the requirement must be met by providing an appropriate argument through ‘pred.args’. For example, for classification with ore.glm or ore.odmGLM the user should specify pred.args=list(type="response").
  • Cross-validation of classification models via embedded R execution of vanilla R functions is not supported yet.
  • Remark: Cross-validation is not a technique intended for large data, as the cost of multiple model training and testing can become prohibitive. Moreover, with large data sets, it is possible to produce effective sampled train and test data sets. The current ore.CV does not impose any restrictions on the size of the input, and the user working with large data should use good judgment when choosing the model generator and the number of folds.
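
The first two restrictions can be checked informally with base R. The sketch below uses a hypothetical helper, accepts_formula_data, to test whether a model generator has 'formula' and 'data' arguments, and then shows the effect of type="response" on a base R glm. It only illustrates the idea with vanilla R functions; the behavior of predict for ORE model objects is not shown here.

```r
# Hypothetical helper: does a model generator accept 'formula' and 'data'?
accepts_formula_data <- function(f) {
  all(c("formula", "data") %in% names(formals(match.fun(f))))
}

accepts_formula_data("lm")    # TRUE
accepts_formula_data("glm")   # TRUE
accepts_formula_data("step")  # FALSE: the signature is step(object, scope, ...)

# For a binomial glm, predict() returns values on the link scale by default;
# type = "response" returns a numeric vector of probabilities instead.
fit  <- glm(am ~ wt, data = mtcars, family = binomial())
link <- predict(fit, newdata = mtcars)
prob <- predict(fit, newdata = mtcars, type = "response")

is.vector(prob)               # TRUE: a 1-dimensional vector of predictions
range(prob)                   # probabilities lie within [0, 1]
```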

    Output

    The function ore.CV provides output on several levels: datastores to contain model results, plots, and text output.

    Datastores

    The results of each cross-validation run are saved into a datastore named dsCV_funTyp_data_Target_function_nFxx where funTyp, function, nF(=nFolds) have been described above and Target is the left-hand-side of the formula. For example, if one runs the ore.neural, ore.glm, and ore.odmNB-based cross-validation examples from above, the following three datastores are produced:
    
    R> ds <- ore.datastore(pattern="dsCV")
    R> print(ds)
                                           datastore.name object.count    size       creation.date description
    1   dsCV_classification_KYPHOSIS_Kyphosis_ore.glm_nF5           10 4480326 2014-04-30 18:19:55        <NA>
    2 dsCV_classification_TITANIC3_survived_ore.odmNB_nF5           10  592083 2014-04-30 18:21:35        <NA>
    3     dsCV_regression_LONGLEY_Employed_ore.neural_nF5           10  497204 2014-04-30 18:16:35        <NA>
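
The naming pattern described above can be reproduced with a short helper. cv_datastore_name is a hypothetical function written for this illustration; the actual ore.CV code may assemble the name differently.

```r
# Hypothetical helper reproducing the datastore naming pattern;
# not taken from the ore.CV source.
cv_datastore_name <- function(funType, fun, formula, dataset, nFolds) {
  target <- all.vars(formula)[1]   # left-hand side of the formula
  paste("dsCV", funType, dataset, target, fun, paste0("nF", nFolds), sep = "_")
}

ds_name <- cv_datastore_name("regression", "ore.neural",
                             Employed ~ GNP + Population + Year, "LONGLEY", 5)
ds_name
# "dsCV_regression_LONGLEY_Employed_ore.neural_nF5"
```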
    
    Each datastore contains the models and prediction tables for every fold. Every prediction table has 3 columns: the fold index together with the target variable/class and the predicted values. If we consider the example from above and examine the most recent datastore (the Naive Bayes classification CV), we would see:
    
    R> ds.last <- ds$datastore.name[which.max(as.numeric(ds$creation.date))]
    R> ore.datastoreSummary(name=ds.last)
        object.name     class   size length row.count col.count
    1   model.fold1 ore.odmNB  66138      9        NA        NA
    2   model.fold2 ore.odmNB  88475      9        NA        NA
    3   model.fold3 ore.odmNB 110598      9        NA        NA
    4   model.fold4 ore.odmNB 133051      9        NA        NA
    5   model.fold5 ore.odmNB 155366      9        NA        NA
    6    test.fold1 ore.frame   7691      3       261         3
    7    test.fold2 ore.frame   7691      3       262         3
    8    test.fold3 ore.frame   7691      3       262         3
    9    test.fold4 ore.frame   7691      3       262         3
    10   test.fold5 ore.frame   7691      3       262         3
    
    

    Plots

    The following plots are generated automatically by ore.CV and saved in an automatically generated OUTPUT directory:

  • Regression: ore.CV compares predicted vs target values, and shows root mean square error (RMSE) and relative error (RERR) boxplots per fold. The example plot is based on 5-fold cross-validation with the ore.lm regression model for Sepal.Length ~.-Species using the ore.frame IRIS dataset.
  • Classification: ore.CV outputs a multi-plot figure for classification metrics like Precision, Recall and F-measure. Each metric is captured per target class (side-by-side barplots) and fold (groups of barplots). The example plot is based on the 5-fold CV of the ore.odmSVM classification model for Species ~. using the ore.frame IRIS dataset.
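
The per-fold regression metrics can be computed from a prediction table with columns for the fold index, the target and the predicted value. The sketch below builds such a table in memory with base R lm on iris and then summarizes it. Note that the relative error formula used here, mean(|prediction - target| / |target|), is an assumption made for illustration; this post does not state how ore.CV defines RERR.

```r
# Build an in-memory prediction table (fold, target, prediction) with base R.
set.seed(1)
folds <- sample(rep(1:5, length.out = nrow(iris)))
pred_tbl <- do.call(rbind, lapply(1:5, function(i) {
  fit <- lm(Sepal.Length ~ . - Species, data = iris[folds != i, ])
  data.frame(fold = i,
             target = iris$Sepal.Length[folds == i],
             prediction = as.vector(predict(fit, iris[folds == i, ])))
}))

by_fold <- split(pred_tbl, pred_tbl$fold)

# Root mean square error per fold
rmse <- sapply(by_fold, function(d) sqrt(mean((d$prediction - d$target)^2)))

# Relative error per fold: ASSUMED definition, mean(|pred - target| / |target|)
rerr <- sapply(by_fold, function(d) mean(abs(d$prediction - d$target) / abs(d$target)))

round(cbind(RMSE = rmse, RERR = rerr), 3)
```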

    Text output

    For classification problems, the confusion tables for each fold are saved in an output file residing in the OUTPUT directory, together with a summary table displaying the precision, recall and F-measure metrics for every fold and predicted class.
    file.show("OUTDIR/tbl_CV_classification_IRIS_Species_ore.odmSVM_nF5")
    
    Confusion table for fold 1 :           
                 setosa versicolor virginica
      setosa          9          0         0
      versicolor      0         12         1
      virginica       0          1         7
    Confusion table for fold 2 :            
                 setosa versicolor virginica
      setosa          9          0         0
      versicolor      0          8         1
      virginica       0          2        10
    Confusion table for fold 3 :           
                 setosa versicolor virginica
      setosa         11          0         0
      versicolor      0         10         2
      virginica       0          0         7
    Confusion table for fold 4 :            
                 setosa versicolor virginica
      setosa          9          0         0
      versicolor      0         10         0
      virginica       0          2         9
    Confusion table for fold 5 :            
                 setosa versicolor virginica
      setosa         12          0         0
      versicolor      0          5         1
      virginica       0          0        12
    Accuracy, Recall & F-measure table per {class,fold}
       fold      class TP  m  n Precision Recall F_meas
    1     1     setosa  9  9  9     1.000  1.000  1.000
    2     1 versicolor 12 13 13     0.923  0.923  0.923
    3     1  virginica  7  8  8     0.875  0.875  0.875
    4     2     setosa  9  9  9     1.000  1.000  1.000
    5     2 versicolor  8  9 10     0.889  0.800  0.842
    6     2  virginica 10 12 11     0.833  0.909  0.870
    7     3     setosa 11 11 11     1.000  1.000  1.000
    8     3 versicolor 10 12 10     0.833  1.000  0.909
    9     3  virginica  7  7  9     1.000  0.778  0.875
    10    4     setosa  9  9  9     1.000  1.000  1.000
    11    4 versicolor 10 10 12     1.000  0.833  0.909
    12    4  virginica  9 11  9     0.818  1.000  0.900
    13    5     setosa 12 12 12     1.000  1.000  1.000
    14    5 versicolor  5  6  5     0.833  1.000  0.909
    15    5  virginica 12 12 13     1.000  0.923  0.960
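
The summary rows can be recomputed from a confusion table with a few lines of base R. The sketch below uses the fold 2 table from the output above. The interpretation that m is the row total and n is the column total of each class is inferred from the printed numbers, not from the ore.CV source.

```r
classes <- c("setosa", "versicolor", "virginica")

# Confusion table for fold 2, copied from the text output
conf <- matrix(c(9, 0,  0,
                 0, 8,  1,
                 0, 2, 10),
               nrow = 3, byrow = TRUE,
               dimnames = list(classes, classes))

TP <- diag(conf)      # correctly classified observations per class
m  <- rowSums(conf)   # row totals; match column 'm' of the summary table
n  <- colSums(conf)   # column totals; match column 'n' of the summary table

precision <- TP / m
recall    <- TP / n
f_meas    <- 2 * precision * recall / (precision + recall)

round(data.frame(TP, m, n, Precision = precision,
                 Recall = recall, F_meas = f_meas), 3)
```

For versicolor this gives Precision = 8/9 = 0.889, Recall = 8/10 = 0.800 and F-measure = 0.842, in agreement with row 5 of the summary table.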
    
              
    What's next
    Several extensions of ore.CV are possible involving sampling, parallel model training and testing, support for vanilla R classifiers, post-processing and output. More material for future posts.