00001 #include "svd.h"
00002 #include "util.h"
00003 #include "linalg.h"
00004 #include <cmath>
00005
00006
00007 #define SIGN(a,b) ((b) > 0.0 ? fabs(a) : -fabs(a))
00008
00009 static double maxarg1, maxarg2;
00010 #define FMAX(a,b) (maxarg1 = (a), maxarg2 = (b), (maxarg1) > (maxarg2) ? (maxarg1) : (maxarg2))
00011
00012 static int iminarg1, iminarg2;
00013 #define IMIN(a,b) (iminarg1 = (a), iminarg2 = (b), (iminarg1) < (iminarg2) ? (iminarg1) : (iminarg2))
00014
00015 static double sqrarg;
00016 #define SQR(a) ((sqrarg = (a)) == 0.0 ? 0.0 : sqrarg * sqrarg)
00017
00018 double pythag(double a, double b) {
00019
00020 double absa = fabs(a);
00021 double absb = fabs(b);
00022 if (absa>absb) {
00023 return absa*sqrt(1.0+SQR(absb/absa));
00024 }
00025 else {
00026 return (absb == 0.0 ? 0.0 : absb*sqrt(1.0+SQR(absa/absb)));
00027 }
00028 }
00029
00030
00031
00032
00033 void svdcmp(matrix &A, matrix &DD, matrix &VV) {
00034 int flag,i,its,j,jj,k,l = 0,nm = 0;
00035 double anorm,c,f,g,h,s,scale,x,y,z;
00036
00037
00038 int m = A.row();
00039 int n = A.column();
00040
00041 matrix W(n,n);
00042 matrix V(n,n);
00043
00044 safevector<double> rv1(n);
00045
00046 g=scale=anorm=0.0;
00047
00048
00049
00050
00051 for (i=0; i<n; i++) {
00052 l=i+1;
00053 rv1[i] = scale*g;
00054 g=s=scale=0.0;
00055 if (i<m) {
00056 for (k=i; k<m; k++) { scale += fabs(A.get(k,i)); }
00057 if (scale) {
00058 for (k=i; k<m; k++) {
00059
00060 A.set(A.get(k,i)/scale, k,i);
00061 s += A.get(k,i)*A.get(k,i);
00062 }
00063 f = A.get(i,i);
00064 g = -SIGN((double)sqrt(s),double(f));
00065 h = f*g-s;
00066 A.set(f-g,i,i);
00067 for (j=l; j<n; j++) {
00068 for (s=0.0, k=i; k<m; k++) { s += A.get(k,i)*A.get(k,j); }
00069 f = s/h;
00070 for (k=i; k<m; k++) { A.set(A.get(k,j)+f*A.get(k,i),k,j); }
00071 }
00072 for (k=i; k<m; k++) { A.set(A.get(k,i)*scale,k,i); }
00073 }
00074 }
00075
00076 W.set(scale*g,i,i);
00077 g=s=scale=0.0;
00078 if (i<m && i!=n-1) {
00079 for (k=l; k<n; k++) { scale += fabs(A.get(i,k)); }
00080 if (scale) {
00081 for(k=l; k<n; k++) {
00082 A.set(A.get(i,k)/scale,i,k);
00083 s += A.get(i,k)*A.get(i,k);
00084 }
00085 f = A.get(i,l);
00086 g = -SIGN(sqrt(s),f);
00087 h = f*g-s;
00088 A.set(f-g,i,l);
00089 for (k=l; k<n; k++) {
00090 rv1[k]=A.get(i,k)/h;
00091 }
00092 for (j=l; j<m; j++) {
00093 for (s=0.0, k=l; k<n; k++) { s+= A.get(j,k)*A.get(i,k); }
00094 for (k=l; k<n; k++) { A.set(s*rv1[k]+A.get(j,k),j,k); }
00095 }
00096 for (k=l; k<n; k++) { A.set(A.get(i,k)*scale, i,k); }
00097 }
00098 }
00099 anorm = FMAX(anorm, (fabs(W.get(i,i))+fabs(rv1[i])));
00100 }
00101
00102
00103
00104
00105 for (i=n-1; i>=0; i--) {
00106 if (i<n-1) {
00107 if (g) {
00108 for (j=l; j<n; j++) {
00109 V.set((A.get(i,j)/A.get(i,l))/g ,j,i);
00110 }
00111 for (j=l; j<n; j++) {
00112 for (s=0.0, k=l; k<n; k++) { s += A.get(i,k)*V.get(k,j); }
00113 for (k=l; k<n; k++) { V.set(V.get(k,j)+s*V.get(k,i), k,j); }
00114 }
00115 }
00116 for (j=l; j<n; j++) { V.set(0,i,j); V.set(0,j,i); }
00117 }
00118 V.set(1.0,i,i);
00119 g = rv1[i];
00120 l = i;
00121 }
00122
00123
00124
00125 for (i=IMIN(m,n)-1; i>=0; i--) {
00126 l = i+1;
00127 g = W.get(i,i);
00128 for (j=l; j<n; j++) { A.set(0.0, i,j); }
00129 if (g) {
00130 g = 1.0/g;
00131 for (j=l; j<n; j++) {
00132 for (s=0.0, k=l; k<m; k++) { s += A.get(k,i)*A.get(k,j); }
00133 f = (s/A.get(i,i))*g;
00134 for (k=i; k<m; k++) { A.set(A.get(k,j)+f*A.get(k,i), k,j); }
00135 }
00136 for (j=i; j<m; j++) { A.set(g*A.get(j,i), j,i); }
00137 }
00138 else {
00139 for (j=i; j<m; j++) { A.set(0.0, j,i); }
00140 }
00141 A.set(A.get(i,i)+1, i,i);
00142 }
00143
00144
00145
00146 for (k=n-1; k>=0; k--) {
00147 for (its=1; its<=100; its++) {
00148 flag = 1;
00149 for (l=k; l>=0; l--) {
00150
00151 nm = l-1;
00152 if ((double)(fabs(rv1[l])+anorm) == anorm) {
00153 flag = 0;
00154 break;
00155 }
00156 if ((double)(fabs(W.get(nm,nm))+anorm) == anorm) { break; }
00157 }
00158 if (l<0) l++;
00159 if (flag) {
00160
00161 c = 0.0;
00162 s = 1.0;
00163 for (i=l; i<=k; i++) {
00164 f = s*rv1[i];
00165 rv1[i] = c*rv1[i];
00166 if ((double)(fabs(f)+anorm) == anorm) { break; }
00167
00168 g = W.get(i,i);
00169 h = pythag(f,g);
00170
00171 W.set(h,i,i);
00172 h = 1.0/h;
00173 c = g*h;
00174 s = -f*h;
00175 for (j=0; j<m; j++) {
00176 y = A.get(j,nm);
00177 z = A.get(j,i);
00178 A.set(y*c+z*c, j, nm);
00179 A.set(z*c-y*s, j,i);
00180 }
00181 }
00182 }
00183
00184
00185
00186 z = W.get(k,k);
00187 if (l == k) {
00188 if (z < 0.0) {
00189 W.set(-z, k,k);
00190 for (j=0; j<n; j++) { V.set(-V.get(j,k), j,k); }
00191 }
00192 break;
00193 }
00194
00195 if (its == 100) {
00196 std::cerr << "No convergence in 100 svdcmp iterations\n";
00197 exit(1);
00198 }
00199
00200
00201
00202 x = W.get(l,l);
00203 nm = k-1;
00204 y = W.get(nm,nm);
00205 g = rv1[nm];
00206 h = rv1[k];
00207 f = ((y-z)*(y+z)+(g-h)*(g+h))/(2.0*h*y);
00208 g = pythag(f,1.0);
00209 f=((x-z)*(x+z)+h*((y/(f+SIGN(g,f)))-h))/x;
00210 c=s=1.0;
00211 for (j=l; j<=nm; j++) {
00212 i = j+1;
00213 g = rv1[i];
00214 y = W.get(i,i);
00215 h = s*g;
00216 g = c*g;
00217 z = pythag(f,h);
00218 rv1[j] = z;
00219 c = f/z;
00220 s = h/z;
00221 f = x*c+g*s;
00222 g = g*c-x*s;
00223 h = y*s;
00224 y *= c;
00225 for (jj=0; jj<n; jj++) {
00226 x = V.get(jj,j);
00227 z = V.get(jj,i);
00228 V.set(x*c+z*s, jj,j);
00229 V.set(z*c-x*s, jj,i);
00230 }
00231 z = pythag(f,h);
00232 W.set(z, j,j);
00233 if (z) {
00234 z = 1.0/z;
00235 c = f*z;
00236 s = h*z;
00237 }
00238 f = c*g+s*y;
00239 x = c*y-s*g;
00240
00241 for (jj=0; jj<m; jj++) {
00242 y = A.get(jj,j);
00243 z = A.get(jj,i);
00244 A.set(y*c+z*s, jj,j);
00245 A.set(z*c-y*s, jj,i);
00246 }
00247 }
00248 rv1[l] = 0.0;
00249 rv1[k] = f;
00250 W.set(x, k,k);
00251 }
00252 }
00253
00254 DD = W;
00255 VV = V;
00256 }
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302 matrix pseudoinverse(matrix const& A, double threshold) {
00303
00304 matrix U=A, D,V;
00305 svdcmp(U,D,V);
00306
00307
00308 matrix iD(D.row(), D.row());
00309 for (int i=0; i<D.row(); i++) {
00310 if (fabs(D.get(i,i))>threshold)
00311 iD.set((double)1/D.get(i,i), i,i);
00312 }
00313 matrix res = V*iD;
00314 res = res*U.trans();
00315 return res;
00316 }
00317
00318
00319 #ifdef TEST
00320
00321 int main(int argc, char *argv[]) {
00322 try {
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332 matrix B(1,4);
00333 B.set(1,0,0); B.set(0,0,1); B.set(1,0,2); B.set(0,0,3);
00334 matrix A = B.trans();
00335
00336
00337 std::cerr.precision(5);
00338 std::cerr << "A = "; A.print();
00339 std::cerr << "-------------------------------------------------------------------" << std::endl;
00340
00341 matrix U = A;
00342 matrix V(1,1), D(1,1);
00343 svdcmp(U,D,V);
00344
00345 for (int i=0; i<D.row(); i++) {
00346 if (fabs(D.get(i,i))<1e-6)
00347 D.set(0,i,i) ;
00348 }
00349
00350 matrix newA = U*D*V.trans();
00351 std::cerr << "newA = "; newA.print();
00352 std::cerr << "U'U = "; (U.trans()*U).print();
00353 std::cerr << "V'V = "; (V.trans()*V).print();
00354
00355 std::cerr << "-------------------------------------------------------------------" << std::endl;
00356
00357 matrix b(4,1);
00358 b.set(0.0, 0,0);
00359 b.set(2.0, 1,0);
00360 b.set(1.0, 2,0);
00361 b.set(7.0, 3,0);
00362
00363 matrix Aplus = pseudoinverse(A);
00364
00365 std::cerr << "Solving Ax = b = "; (b.trans()).print(); std::cerr << std::endl;
00366 matrix x = Aplus*b;
00367 std::cerr << "Apinv*b = x' = "; x.trans().print(); std::cerr << std::endl;
00368 std::cerr << "A*x' = "; (A*x).trans().print(); std::cerr << std::endl;
00369
00370 } catch(std::string const& err) {
00371 std::cerr << err << std::endl;
00372 return 1;
00373 }
00374 return 0;
00375 }
00376
00377
00378 #endif//TEST
00379
00380
00381
00382