Tuesday, April 22, 2014

Sparse matrices should be operated on with suitable data structures

From Wikipedia:

When storing and manipulating sparse matrices on a computer, it is beneficial and often necessary to use specialized algorithms and data structures that take advantage of the sparse structure of the matrix. Operations using standard dense-matrix structures and algorithms are relatively slow and consume large amounts of memory when applied to large sparse matrices. Sparse data is by nature easily compressed, and this compression almost always results in significantly less computer data storage usage. Indeed, some very large sparse matrices are infeasible to manipulate using standard dense algorithms.
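
As a quick illustration of the storage side, here is a minimal sketch using R's Matrix package (the same package used in the demonstration below); the exact sizes reported will vary slightly by R version:

w = rnorm(9000)
object.size(diag(w))                 # dense 9000 x 9000 matrix: roughly 648 MB
object.size(Matrix::Diagonal(x=w))   # sparse diagonal: only the 9000 nonzero entries, ~72 KB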

R has the Matrix package, which provides special data structures for dealing with this. Here is a demonstration in R:

Performance test of sparse matrix in R
library(Matrix)

n = 9000
k = 100
u = matrix(rnorm(n*k), k)   # k x n dense matrix
v = matrix(rnorm(n*k), n)   # n x k dense matrix
w = rnorm(n)
W1 = diag(w)                # dense n x n diagonal matrix
W2 = Diagonal(x=w)          # sparse diagonal matrix from the Matrix package
system.time(u %*% W1 %*% v)
system.time(u %*% W2 %*% v)

#////////////////////////////////////////////////////
system.time(u %*% W1 %*% v)
   user  system elapsed
  1.956   0.000   1.956
#////////////////////////////////////////////////////
system.time(u %*% W2 %*% v)
   user  system elapsed
  0.081   0.004   0.086
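
Why the gap is so large: diag(w) materializes a full 9000 x 9000 dense matrix and the product runs through dense matrix multiplication, while Diagonal(x=w) stores only the 9000 diagonal entries and multiplying by it amounts to scaling the columns of u. A minimal sketch of that equivalence, reusing the objects from the benchmark above (sweep is just the base-R way to do the same column scaling without ever forming diag(w)):

prod_sparse = as.matrix(u %*% Diagonal(x=w))   # k x n product via the sparse diagonal
prod_scaled = sweep(u, 2, w, "*")              # same result: scale column j of u by w[j]
max(abs(prod_sparse - prod_scaled))            # essentially 0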

This brings about a staggering improvement of over 20 times! Let's revisit a previous example: logistic regression implemented with batch gradient descent (in that post we didn't calculate the standard errors of the coefficients; if we do, glm will beat us by a large margin!), but this time taking advantage of sparse matrices:

Logistic regression implementation improved
require(Matrix)
logreg = function(y, x) {
 x = as.matrix(x)
 x = apply(x, 2, scale)   # standardize the predictors
 x = cbind(1, x)          # add an intercept column
 m = nrow(x)
 n = ncol(x)
 alpha = 2/m              # initial step size for gradient descent

 # b = matrix(rnorm(n))
 # b = matrix(summary(lm(y~x))$coef[, 1])
 b = matrix(rep(0, n))    # start from zero coefficients
 v = exp(-x %*% b)
 h = 1 / (1 + v)          # logistic function of the linear predictor

 J = -(t(y) %*% log(h) + t(1-y) %*% log(1 - h))   # negative log-likelihood (cost)
 derivJ = t(x) %*% (h - y)                        # gradient of the cost


 niter = 0
 while(1) {
  niter = niter + 1
  newb = b - alpha * derivJ    # gradient descent step
  v = exp(-x %*% newb)
  h = 1 / (1 + v)
  newJ = -(t(y) %*% log(h) + t(1-y) %*% log(1 - h))
  # backtracking: shrink the step size until the cost actually decreases
  while((newJ - J) >= 0) {
   print("inner while...")
   # step adjust
   alpha = alpha / 1.15
   newb = b - alpha * derivJ
   v = exp(-x %*% newb)
   h = 1 / (1 + v)
   newJ = -(t(y) %*% log(h) + t(1-y) %*% log(1 - h))
  }
  if(max(abs(b - newb)) < 0.001) {
   break
  }
  b = newb
  J = newJ
  derivJ = t(x) %*% (h-y)
 }
 w = h^2 * v                       # logistic weights h*(1-h), since 1-h = v*h
 # hessian matrix of the cost function, built with a sparse diagonal weight matrix
 hess = t(x) %*% Diagonal(x = as.vector(w)) %*% x
 seMat = sqrt(diag(solve(hess)))   # standard errors from the inverse Hessian
 zscore = b / seMat                # Wald z-scores
 cbind(b, zscore)
}
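
A side note on the weights used for the Hessian above: since h = 1/(1 + v) with v = exp(-x %*% b), we have 1 - h = v*h, so w = h^2 * v is exactly the usual logistic regression weight h*(1 - h). A tiny sketch checking the identity numerically (eta here is just an arbitrary linear predictor made up for the illustration):

eta = rnorm(5)
v = exp(-eta)
h = 1 / (1 + v)
all.equal(h^2 * v, h * (1 - h))   # TRUE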

nr = 5000
nc = 5
# set.seed(17)
x = matrix(rnorm(nr*nc, 0, 999), nr)
x = apply(x, 2, scale)
# y = matrix(sample(0:1, nr, repl=T), nr)
h = 1/(1 + exp(-x %*% rnorm(nc)))                     # true logistic probabilities
y = round(h)                                          # deterministic responses
y[1:round(nr/2)] = sample(0:1, round(nr/2), repl=T)   # randomize half of them to add noise


ntests = 300
testglm = function() {
 for(i in 1:ntests) {
  res = summary(glm(y~x, family=binomial))$coef
 }
 print(res)
}

testlogreg = function() {
 for(i in 1:ntests) {
  res = logreg(y, x)
 }
 print(res)
}

print(system.time(testlogreg()))
print(system.time(testglm()))

# benchmark results
#////////////////////////////////////////////////////
print(system.time(testlogreg()))
            [,1]      [,2]
[1,] -0.00639276 -0.205752
[2,] -0.23057674 -7.327096
[3,] -0.05069555 -1.634037
[4,] -0.00793098 -0.254736
[5,]  0.47984767 14.795614
[6,]  0.84068549 23.592259
   user  system elapsed
  8.671   0.000   8.673
#////////////////////////////////////////////////////
print(system.time(testglm()))
               Estimate Std. Error   z value     Pr(>|z|)
(Intercept) -0.00641490  0.0310827 -0.206382  8.36493e-01
x1          -0.23164275  0.0314839 -7.357510  1.87372e-13
x2          -0.05082648  0.0310369 -1.637614  1.01502e-01
x3          -0.00789098  0.0311468 -0.253348  7.99999e-01
x4           0.48154867  0.0324508 14.839341  8.15527e-50
x5           0.84393705  0.0356694 23.659992 9.31454e-124
   user  system elapsed
  7.571   0.000   7.572

Not so bad; sometimes our algorithm can even beat glm.
