Main Page | Class Hierarchy | Class List | File List | Class Members | File Members

svd.cpp

#include "svd.h"

#include <cmath>
#include <cstdlib>
#include <iostream>

#include "util.h"
#include "linalg.h"
00005 
00006 
// Small numeric helpers used by the SVD code in this file.
//
// They used to be Numerical-Recipes-style macros that stashed their arguments
// in file-level static variables. As functions, each argument is evaluated
// exactly once with no shared mutable state. Concurrent callers therefore no
// longer race on the scratch variables. Nested calls such as
// FMAX(a, FMAX(b, c)) also no longer clobber each other's stashed arguments.
// The names are kept so every existing call site compiles unchanged.

// |a| carrying the sign of b: +|a| when b > 0, otherwise -|a| (including b == 0).
static inline double SIGN(double a, double b) {
  return b > 0.0 ? std::fabs(a) : -std::fabs(a);
}

// a if a > b, otherwise b (so ties, and a NaN in either argument, yield b).
static inline double FMAX(double a, double b) {
  return a > b ? a : b;
}

// a if a < b, otherwise b.
static inline int IMIN(int a, int b) {
  return a < b ? a : b;
}

// a squared; returns exactly 0.0 when a == 0.
static inline double SQR(double a) {
  return a == 0.0 ? 0.0 : a * a;
}
00017 
// Returns sqrt(a*a + b*b).
// The larger magnitude is factored out first, so the value that gets squared
// is a ratio of at most 1. This keeps the intermediate square from
// overflowing or underflowing even when a*a itself would.
double pythag(double a, double b) {
  const double x = std::fabs(a);
  const double y = std::fabs(b);

  if (!(x > y)) {
    // |b| dominates (or ties). With |b| == 0 there is nothing to scale by.
    if (y == 0.0) {
      return 0.0;
    }
    const double ratio = x / y;
    return y * std::sqrt(1.0 + ratio * ratio);
  }

  const double ratio = y / x;
  return x * std::sqrt(1.0 + ratio * ratio);
}
00029 
00030 // SVD-decomposition A = U*W*V', where U,V orthogonal and W diagonal. The matrix U is
00031 // replacing A and notice that V and not the transpose V' is returned. 
00032 // Blackbox-algorithm from NR.
00033 void svdcmp(matrix &A, matrix &DD, matrix &VV) {
00034   int flag,i,its,j,jj,k,l = 0,nm = 0;
00035   double anorm,c,f,g,h,s,scale,x,y,z;
00036 
00037   // Dimensions of A;
00038   int m = A.row();
00039   int n = A.column();
00040   
00041   matrix W(n,n);
00042   matrix V(n,n);
00043 
00044   safevector<double> rv1(n);
00045 
00046   g=scale=anorm=0.0;
00047 
00048   //  cerr << "Start of algorithm\n";
00049   
00050   // householder reduction to bidiagonal form
00051   for (i=0; i<n; i++) {
00052     l=i+1;
00053     rv1[i] = scale*g;
00054     g=s=scale=0.0;
00055     if (i<m) {
00056       for (k=i; k<m; k++)   { scale += fabs(A.get(k,i));   }
00057       if (scale) {
00058         for (k=i; k<m; k++) {
00059           //      cerr << "A( " << k << ", " << i << ")";
00060           A.set(A.get(k,i)/scale, k,i);
00061           s += A.get(k,i)*A.get(k,i);
00062         }
00063         f = A.get(i,i);
00064         g = -SIGN((double)sqrt(s),double(f));
00065         h = f*g-s;
00066         A.set(f-g,i,i);
00067         for (j=l; j<n; j++) {
00068           for (s=0.0, k=i; k<m; k++) {  s += A.get(k,i)*A.get(k,j);  }
00069           f = s/h;
00070           for (k=i; k<m; k++) {  A.set(A.get(k,j)+f*A.get(k,i),k,j);  }
00071         }
00072         for (k=i; k<m; k++) {  A.set(A.get(k,i)*scale,k,i); }
00073       } // if (scale)
00074     } // if (i<=m)
00075 
00076     W.set(scale*g,i,i);
00077     g=s=scale=0.0;
00078     if (i<m && i!=n-1) {
00079       for (k=l; k<n; k++) { scale += fabs(A.get(i,k)); }
00080       if (scale) {
00081         for(k=l; k<n; k++) {
00082         A.set(A.get(i,k)/scale,i,k);
00083         s += A.get(i,k)*A.get(i,k);
00084         }
00085         f = A.get(i,l);
00086         g = -SIGN(sqrt(s),f);
00087         h = f*g-s;
00088         A.set(f-g,i,l);
00089         for (k=l; k<n; k++) {
00090           rv1[k]=A.get(i,k)/h;
00091         }
00092         for (j=l; j<m; j++) {
00093           for (s=0.0, k=l; k<n; k++) { s+= A.get(j,k)*A.get(i,k); }
00094           for (k=l; k<n; k++) { A.set(s*rv1[k]+A.get(j,k),j,k); }
00095         }
00096         for (k=l; k<n; k++) { A.set(A.get(i,k)*scale, i,k); }
00097       }// if (scale)
00098     } // if(i<=m && i!=n)
00099     anorm = FMAX(anorm, (fabs(W.get(i,i))+fabs(rv1[i])));
00100   } //  for (i=1; i<=n; i++) {
00101 
00102 
00103   //  cerr << "End of part 1\n";
00104 
00105   for (i=n-1; i>=0; i--) { // Accumulation of right-hand transformations.
00106     if (i<n-1) {
00107       if (g) {
00108         for (j=l; j<n; j++) { // Double division to avoid possible underflow 
00109           V.set((A.get(i,j)/A.get(i,l))/g ,j,i);
00110         }
00111         for (j=l; j<n; j++) {
00112           for (s=0.0, k=l; k<n; k++) { s +=  A.get(i,k)*V.get(k,j); }
00113           for (k=l; k<n; k++) { V.set(V.get(k,j)+s*V.get(k,i), k,j); }
00114         }
00115       } // if (g)
00116       for (j=l; j<n; j++) { V.set(0,i,j); V.set(0,j,i); }
00117     } // if (i<n)
00118     V.set(1.0,i,i);
00119     g = rv1[i];
00120     l = i;
00121   } // for (i=n; i>=1; i--);
00122 
00123   //  cerr << "End of part 2\n";
00124      
00125   for (i=IMIN(m,n)-1; i>=0; i--) { // Accumulation of left-hand transformations
00126     l = i+1;
00127     g = W.get(i,i);
00128     for (j=l; j<n; j++) { A.set(0.0, i,j); }
00129     if (g) {
00130       g = 1.0/g;
00131       for (j=l; j<n; j++) {
00132         for (s=0.0, k=l; k<m; k++) {  s += A.get(k,i)*A.get(k,j); }
00133         f = (s/A.get(i,i))*g;
00134         for (k=i; k<m; k++) { A.set(A.get(k,j)+f*A.get(k,i), k,j); }
00135       }
00136       for (j=i; j<m; j++) { A.set(g*A.get(j,i), j,i);  }
00137     } // if (g)
00138     else { 
00139       for (j=i; j<m; j++) { A.set(0.0, j,i); }
00140     }
00141     A.set(A.get(i,i)+1, i,i);
00142   } // for (i=IMIN(m,n); i>=1; i--)
00143 
00144   //  cerr << "End of part 3\n";
00145   
00146   for (k=n-1; k>=0; k--) { // Diagonalization of the bidiagonal form; Loop over singular values, and over all allowed iterations
00147     for (its=1; its<=100; its++) {
00148       flag = 1;
00149       for (l=k; l>=0; l--) { // Test for splitting
00150         //      cerr << "\nl=" << l << "   k=" << k <<  endl;
00151         nm = l-1;          // Note that rv1[0] is always zero.
00152         if ((double)(fabs(rv1[l])+anorm) == anorm) {
00153           flag = 0;
00154           break;
00155         }
00156         if ((double)(fabs(W.get(nm,nm))+anorm) == anorm) { break; }
00157       } // for (l=k; l>=1; l--)
00158       if (l<0) l++;
00159       if (flag) { // cancelation of rv1[l], if l>1.
00160         //      cerr << "\nl=" << l << "   k=" << k <<  endl;
00161         c = 0.0;
00162         s = 1.0;
00163         for (i=l; i<=k; i++) {
00164           f = s*rv1[i];
00165           rv1[i] = c*rv1[i];
00166           if ((double)(fabs(f)+anorm) == anorm) { break; }
00167           //      cerr << "W.rows=" << W.row() << " W.cols=" << W.column() << "  i=" << i << endl;
00168           g = W.get(i,i);
00169           h = pythag(f,g);
00170           //      cerr << "W.rows=" << W.row() << " W.cols=" << W.column() << "  i=" << i << endl;
00171           W.set(h,i,i);
00172           h = 1.0/h;
00173           c = g*h;
00174           s = -f*h;
00175           for (j=0; j<m; j++) {
00176             y = A.get(j,nm);
00177             z = A.get(j,i);
00178             A.set(y*c+z*c, j, nm);
00179             A.set(z*c-y*s, j,i);
00180           }
00181         } // for (i=1; i<=k; i++)
00182       } // if (flag)
00183       
00184       //      cerr << "End of part 4\n";
00185   
00186       z = W.get(k,k);
00187       if (l == k) { // Convergence
00188         if (z < 0.0) { // Singular values are made nonnegative
00189           W.set(-z, k,k);
00190           for (j=0; j<n; j++) { V.set(-V.get(j,k), j,k); }
00191         }
00192         break;
00193       } // if (l == k)
00194       
00195       if (its == 100) {
00196         std::cerr << "No convergence in 100 svdcmp iterations\n";
00197         exit(1);
00198       }
00199       
00200       //      cerr << "End of part 5\n";
00201 
00202       x = W.get(l,l);   // Shift from bottom 2-by-2 mirror
00203       nm = k-1;
00204       y = W.get(nm,nm);
00205       g = rv1[nm];
00206       h = rv1[k];
00207         f = ((y-z)*(y+z)+(g-h)*(g+h))/(2.0*h*y);
00208       g = pythag(f,1.0);
00209       f=((x-z)*(x+z)+h*((y/(f+SIGN(g,f)))-h))/x;
00210       c=s=1.0;  // Next QR transformation
00211       for (j=l; j<=nm; j++) {
00212         i = j+1;
00213         g = rv1[i];
00214         y = W.get(i,i);
00215         h = s*g;
00216         g = c*g;
00217         z = pythag(f,h);
00218         rv1[j] = z;
00219         c = f/z;
00220         s = h/z;
00221         f = x*c+g*s;
00222         g = g*c-x*s;
00223         h = y*s;
00224         y *= c;
00225         for (jj=0; jj<n; jj++) {
00226           x = V.get(jj,j);
00227           z = V.get(jj,i);
00228           V.set(x*c+z*s, jj,j);
00229           V.set(z*c-x*s, jj,i);
00230         }
00231         z = pythag(f,h);
00232         W.set(z, j,j);
00233         if  (z) {  // Rotation can be arbitrary if z=0;
00234           z = 1.0/z;
00235           c = f*z;
00236           s = h*z;
00237         }
00238         f = c*g+s*y;
00239         x = c*y-s*g;
00240 
00241         for (jj=0; jj<m; jj++) {
00242           y = A.get(jj,j);
00243           z = A.get(jj,i);
00244           A.set(y*c+z*s, jj,j);
00245           A.set(z*c-y*s, jj,i);
00246         }
00247       } // for (j=1; k<=nm; j++)
00248       rv1[l] = 0.0;
00249       rv1[k] = f;
00250       W.set(x, k,k);   
00251     } // for (its=1; its<=30; its++)
00252   } // for (k=n; k>=1; k--)
00253 
00254   DD  = W;
00255   VV = V;
00256 } // end of function      
00257 
00258 
00259 /*
00260 void redsvd(matrix &A, matrix &D, matrix &V) {
00261 
00262   double threshold = 1e-6;
00263 
00264   svdcmp(A, D, V);
00265 
00266 //   cerr << "dim(V) = " << V.height() << "x" << V.width() << ";  ";
00267 //   cerr << "dim(D) = " << D.height() << "x" << D.width() << ";  ";
00268 //   cerr << "dim(A) = " << A.height() << "x" << A.width() << ";  ";
00269 //   cerr << endl << endl;
00270   
00271   //  cerr << "D_original = " << endl << D << endl;
00272   unsigned N=D.width();
00273   safevector<unsigned> deleteindex;
00274   safevector<double> diagonal;
00275   for (unsigned i=0; i<N; i++) {
00276     if (fabs(D(i,i))>threshold) {
00277       diagonal.push_back(D(i,i));
00278     }
00279     else deleteindex.push_back(i);
00280   }
00281 
00282   //  cerr << "D.size = " << diagonal.size() << endl;
00283   Matrix<double> DD(diagonal.size(), diagonal.size());
00284   for (unsigned i=0; i<diagonal.size(); i++) {
00285     DD(i,i) = diagonal[i];
00286   }
00287   D = DD;
00288 
00289   cerr << "Deleteindex: " << deleteindex << endl;
00290   A.remove_col(deleteindex);
00291   V.remove_col(deleteindex);
00292     cerr << "A(pseudoinv)=" << endl << A << endl;
00293     cerr << "V(pseudoinv)=" << endl << V << endl;
00294 
00295 
00296 //   cerr << "U=" << endl << U << endl;
00297 //   cerr << "V=" << endl << V << endl;
00298 } // redsvd
00299 */
00300 
00301 
00302 matrix pseudoinverse(matrix const& A, double threshold) {
00303 
00304   matrix U=A, D,V;
00305   svdcmp(U,D,V);
00306 
00307   //  cerr << endl << endl << "D=" << endl << D << endl;
00308   matrix iD(D.row(), D.row()); 
00309   for (int i=0; i<D.row(); i++) { // Truncate (near) singular values
00310     if (fabs(D.get(i,i))>threshold)
00311       iD.set((double)1/D.get(i,i), i,i);
00312   }
00313     matrix res = V*iD;
00314     res = res*U.trans();
00315     return res;
00316 }
00317 
00318 
#ifdef TEST

// Stand-alone self-test (build with -DTEST).
// It factors a 4x1 matrix, prints the reconstruction U*D*V' and the products
// U'U and V'V, then solves A x = b in the least-squares sense via the
// pseudoinverse. A std::string thrown along the way is printed and turned
// into exit status 1.
int main(int argc, char *argv[]) {
  try {
    // Alternative 4x4 test input (disabled), row by row:
    //   [1 0 1 0; 0 1 0 1; 2 0 2 0; 0 0 1 1]

    // A is the column vector (1, 0, 1, 0)', built as the transpose of a row.
    const int entries[4] = {1, 0, 1, 0};
    matrix row_vec(1, 4);
    for (int col = 0; col < 4; col++) {
      row_vec.set(entries[col], 0, col);
    }
    matrix A = row_vec.trans();

    std::cerr.precision(5);
    std::cerr << "A = ";
    A.print();
    std::cerr << "-------------------------------------------------------------------" << std::endl;

    // svdcmp overwrites its first argument with U, so decompose a copy.
    matrix U = A;
    matrix V(1, 1);
    matrix D(1, 1);
    svdcmp(U, D, V);

    // Zero out (near) vanishing singular values before reconstructing.
    for (int i = 0; i < D.row(); i++) {
      if (std::fabs(D.get(i, i)) < 1e-6) {
        D.set(0, i, i);
      }
    }

    matrix reconstructed = U * D * V.trans();
    std::cerr << "newA = ";
    reconstructed.print();
    std::cerr << "U'U = ";
    (U.trans() * U).print();
    std::cerr << "V'V = ";
    (V.trans() * V).print();

    std::cerr << "-------------------------------------------------------------------" << std::endl;

    // Right-hand side b = (0, 2, 1, 7)'.
    const double rhs[4] = {0.0, 2.0, 1.0, 7.0};
    matrix b(4, 1);
    for (int r = 0; r < 4; r++) {
      b.set(rhs[r], r, 0);
    }

    matrix Aplus = pseudoinverse(A);

    std::cerr << "Solving Ax = b = ";
    (b.trans()).print();
    std::cerr << std::endl;

    matrix x = Aplus * b;
    std::cerr << "Apinv*b = x'  = ";
    x.trans().print();
    std::cerr << std::endl;

    std::cerr << "A*x' = ";
    (A * x).trans().print();
    std::cerr << std::endl;
  } catch (std::string const &err) {
    std::cerr << err << std::endl;
    return 1;
  }
  return 0;
}

#endif//TEST
00379    
00380       
00381 
00382                 

Generated on Tue Feb 14 16:05:52 2006 for estfunc by doxygen 1.3.6