Friday, April 18, 2014

Bias-variance trade-off

In a previous post we discussed VC dimension and VC analysis; in this post let's look at the same thing from a slightly different perspective.

With VC analysis we decomposed $E_{out}$ into $E_{in}$ and the penalty term $\Omega$; with bias-variance analysis we perform a different decomposition. ($E_{out}$ and $E_{in}$ can be interpreted either as the probability of making a wrong classification or as the amount of error.) Given one data set $D$ drawn from the universe $U$, learning produces a final optimized hypothesis $g_D$, and the average squared error over the input space is \begin{align*} E_x[ (g_D(x) - f(x))^2 ] \end{align*} where $f$ is the unknown target function (in simulations it can of course be known). When more data sets are sampled, we need to take another average: \begin{align*} E_{out} &= E_D[ E_x[ (g_D(x) - f(x))^2 ] ] \end{align*} Here comes a very smart move: switch $E_D$ and $E_x$. Define the average hypothesis $\overline{g}(x) = E_D[g_D(x)]$, then: \begin{align*} E_{out} &= E_x[ E_D[ (g_D(x) -\overline{g}(x) + \overline{g}(x) - f(x))^2 ] ] \\ &= E_x[ E_D[ (g_D(x) -\overline{g}(x))^2 ] ] + E_x[ E_D[ (\overline{g}(x) - f(x))^2 ] ] \\ &= E_x[ E_D[ (g_D(x) -\overline{g}(x))^2 ] ] + E_x[ (\overline{g}(x) - f(x))^2 ] \\ &\approx \frac{\sum_x\sum_D(g_D(x) -\overline{g}(x))^2}{N_DN_x} + E_x[ (\overline{g}(x) - f(x))^2 ] \end{align*} The cross term $2\,E_x[ E_D[ (g_D(x) - \overline{g}(x))(\overline{g}(x) - f(x)) ] ]$ vanishes because $E_D[g_D(x) - \overline{g}(x)] = 0$, which is why it drops out in the second line; the last line is the empirical estimate computed by averaging over $N_D$ data sets and $N_x$ input points, as in the simulation below. Now the decomposition is finished: the first term is the variance, the second is the bias.
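As a sanity check on the algebra (this snippet is my own, not from the lecture), the decomposition can be verified numerically: fill a matrix with simulated hypotheses $g_D(x)$, one column per data set, and compare the left-hand side with bias plus variance.

# verify E_out = bias + variance on simulated hypotheses
# (standalone sketch; the names and the noise model are made up for illustration)
set.seed(1)
Nx = 201; ND = 50
x = seq(-1, 1, length.out = Nx)
fx = sin(pi * x)                                # target values on the grid
gD = replicate(ND, fx + rnorm(Nx, sd = 0.5))    # Nx x ND matrix, column j = g_{D_j}(x)
gbar = rowMeans(gD)                             # average hypothesis gbar(x)
eout     = mean((gD - fx)^2)                    # E_D[ E_x[ (g_D(x) - f(x))^2 ] ]
variance = mean((gD - gbar)^2)                  # E_x[ E_D[ (g_D(x) - gbar(x))^2 ] ]
bias2    = mean((gbar - fx)^2)                  # E_x[ (gbar(x) - f(x))^2 ]
all.equal(eout, variance + bias2)               # TRUE: the cross term cancels

The equality holds up to floating-point error because $\overline{g}$ is the column-wise mean of the sampled hypotheses, which is exactly why the cross term vanishes.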

The following example takes $\sin(\pi x)$ as the target function, generates a universe over $(-1, 1)$, and then randomly samples two data points at a time to form each data set.

Two learning algorithms are compared: the first takes the midpoint of the two points and draws a horizontal line through it; the second draws a straight line through both points. Check which one has higher bias, which has larger variance, and which performs better overall.
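As a warm-up before the full simulation, here is a minimal sketch of what each algorithm does on a single data set of two points (the names fit_constant and fit_line are mine, not from the lecture):

# the two learning algorithms applied to one data set {(x1, y1), (x2, y2)}
fit_constant = function(x1, y1, x2, y2) {
 h = (y1 + y2) / 2                # midpoint of the two y values
 function(x) rep(h, length(x))    # horizontal line y = h
}
fit_line = function(x1, y1, x2, y2) {
 b = (y2 - y1) / (x2 - x1)        # slope (undefined when x1 == x2)
 a = y1 - b * x1                  # intercept
 function(x) a + b * x            # line through both points
}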

R code for the above figures
stepsize = .01
samplestimes = 30                        # number of data sets D
universex = seq(-1, 1, stepsize)         # the universe of x values
universey = sin(pi * universex)          # target function f(x) = sin(pi * x)
universexy = data.frame(universex, universey)

# each data set is a pair of points; draw all first points, then all second points
samplex1 = sample(universex, samplestimes, replace = TRUE)
sampley1 = sin(pi * samplex1)
samplex2 = sample(universex, samplestimes, replace = TRUE)
sampley2 = sin(pi * samplex2)
samp = data.frame(samplex1, samplex2, sampley1, sampley2)

# line through the two points: slope b, intercept a
samp$b = (samp$sampley1 - samp$sampley2) / (samp$samplex1 - samp$samplex2)
samp$a = samp$sampley1 - samp$b * samp$samplex1
# horizontal line through the midpoint of the two y values
samp$h = (samp$sampley1 + samp$sampley2) / 2

# if the two sampled x values coincide the slope is NaN; keep complete cases only
require(mice)
samp = cc(samp)
head(samp)

# function factory for producing line functions
makeline = function(a, b) {
 force(a)   # force evaluation so each closure captures its own a and b
 force(b)
 function(x) {
  a + b*x
 }
}

# store all these line functions in a list
# gd1 gd2 gd3 ...
linelist = list()
for(i in 1:nrow(samp)) {
 a = samp[i, ]$a
 b = samp[i, ]$b
 linelist[[i]] = makeline(a, b)
}

# produce the matrix of hypothesis values:
# column i is gd_i(x) evaluated over the whole universe
gx = sapply(linelist, function(f) f(universex))
gxbar = matrix(rowMeans(gx))               # average hypothesis gbar(x)
gxbarMat = gxbar[, rep(1, ncol(gx))]       # replicate gbar(x) across columns
gvar = mean((gx - gxbarMat)^2)             # variance term
gbias = mean((gxbar - universey)^2)        # bias term
gbias
gvar
gxvector = as.vector(gx)
xgxdat = data.frame(universex, gxvector)   # long format for plotting

# now consider the constant function cf, produce
# cfd1(x) cfd2(x)... as columns of a matrix
cfx = matrix(samp$h, nrow=1)
cfx = cfx[rep(1, length(universex)), ]       # repeat each constant over the universe
cfxbar = matrix(rowMeans(cfx))               # average constant hypothesis
cfxbarMat = cfxbar[, rep(1, ncol(cfx))]
cfvar = mean((cfx - cfxbarMat)^2)            # variance term
cfbias = mean((cfxbar - universey)^2)        # bias term
cfvar
cfbias
cfxvector = as.vector(cfx)
xcfxdat = data.frame(universex, cfxvector)   # long format for plotting
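
# compare the two models (extra check, not in the original code):
# the estimated expected out-of-sample error is the sum of the bias
# and variance estimates computed above
gbias + gvar      # line through both points
cfbias + cfvar    # horizontal line through the midpoint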

require(ggplot2)
# plot line functions
plotcurve = ggplot(universexy, aes(universex, universey)) + geom_line()
for(i in 1:nrow(samp)) {
 plotcurve = plotcurve +
 geom_abline(intercept=samp[i, ]$a, slope=samp[i, ]$b, color="#2B960025")
}
# smooth through all g_D(x) values: the fitted curve approximates gbar(x),
# and the near-1 confidence level just draws a wide ribbon around it
plotcurve = plotcurve + stat_smooth(data=xgxdat,
  aes(universex, gxvector),
  level=.99999999999999)
print(plotcurve)

# plot constant functions
plotcurve = ggplot(universexy, aes(universex, universey)) + geom_line()
for(i in 1:nrow(samp)) {
 plotcurve = plotcurve +
 geom_hline(yintercept=samp[i, ]$h, color="#2B960025")
}
# same smoothing trick for the constant hypotheses
plotcurve = plotcurve + stat_smooth(data=xcfxdat,
 aes(universex, cfxvector),
 level=.999999999999999)
print(plotcurve)

This post is an annotation of a lecture video from Caltech.
