# Code to accompany lecture 4, 36-402 Data Analysis, spring 2011 # Multi-fold cross-validation for univariate kernel regression # Inputs: vector of input feature values (x) # vector of response values (y) # vector of bandwidths (bandwidths) # number of folds (num.folds) # Calls: npreg from the np package # Output: list with components of: best bandwidth, vector of # cross-validated MSEs, array of MSE for each bandwidth on each fold cv_bws_npreg <- function(x,y,bandwidths=(1:50)/50,num.folds=10) { require(np) n = length(x) # Sanity-check inputs require(np) n <- length(x) stopifnot(n> 1, length(y) == n) stopifnot(length(bandwidths) > 1) stopifnot(num.folds > 0, num.folds==trunc(num.folds)) # Make a matrix to store MSEs for each fold/bandwidth combination fold_MSEs <- matrix(0,nrow=num.folds,ncol=length(bandwidths)) # Name the columns after bandwidths for easy reference later # coerces the numerical bandwidths into character strings colnames(fold_MSEs) = bandwidths # Divide the data randomly into "folds" (non-overlapping testing sets) # Assign each data point to one of the "folds", in order (so that size is # as nearly equal as possible) case.folds <- rep(1:num.folds,length.out=n) # Now randomly permute the order (see help(sample) for how this works) case.folds <- sample(case.folds) for (fold in 1:num.folds) { train.rows = which(case.folds==fold) # Separate training and testing sets x.train = x[train.rows] y.train = y[train.rows] x.test = x[-train.rows] y.test = y[-train.rows] # Cycle over bandwidths for (bw in bandwidths) { fit <- npreg(txdat=x.train,tydat=y.train, exdat=x.test,eydat=y.test,bws=bw) # Invoking npreg this way tells it to fit on the training data (txdat, # tydat), but evaluate on different data (exdat, eydat) # See help(npreg) for more fold_MSEs[fold,paste(bw)] <- fit$MSE # Here MSE is the MSE on the evaluation data # paste(bw): turns numerical bandwidth to type character, so R knows # it's the name of a column, not a column index } } # Average over folds CV_MSEs = colMeans(fold_MSEs) # Find the best bandwidth best.bw = bandwidths[which.min(CV_MSEs)] return(list(best.bw=best.bw,CV_MSEs=CV_MSEs,fold_MSEs=fold_MSEs)) }