001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018import org.apache.commons.math3.complex.Complex;
019import org.apache.commons.math3.linear.Array2DRowRealMatrix;
020import org.apache.commons.math3.linear.ArrayRealVector;
021import org.apache.commons.math3.linear.CholeskyDecomposition;
022import org.apache.commons.math3.linear.ConjugateGradient;
023import org.apache.commons.math3.linear.EigenDecomposition;
024import org.apache.commons.math3.linear.LUDecomposition;
025import org.apache.commons.math3.linear.MatrixUtils;
026import org.apache.commons.math3.linear.QRDecomposition;
027import org.apache.commons.math3.linear.RealLinearOperator;
028import org.apache.commons.math3.linear.RealMatrix;
029import org.apache.commons.math3.linear.RealVector;
030import org.apache.commons.math3.linear.SingularValueDecomposition;
031
032
033public class LinearAlgebra {
034
        /**
         * Summing-axis length threshold for the single-axis tensorDotProduct: when the summed axis
         * of the first dataset is shorter than this, that method delegates to the multi-axis
         * (position-iterating) overload; otherwise it uses slice iterators for the inner loop.
         * NOTE(review): the value appears to be an empirically chosen crossover point.
         */
        private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster 
036
037        /**
038         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
039         * from the given axes in each dataset
040         * @param a
041         * @param b
042         * @param axisa axis dimension in a to sum over (can be -ve)
043         * @param axisb axis dimension in b to sum over (can be -ve)
044         * @return tensor dot product
045         */
046        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) {
047                // this is slower for summing lengths < ~15
048                final int[] ashape = a.getShapeRef();
049                final int[] bshape = b.getShapeRef();
050                final int arank = ashape.length;
051                final int brank = bshape.length;
052                int aaxis = ShapeUtils.checkAxis(arank, axisa);
053
054                if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration
055                        return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb});
056                }
057                int baxis = ShapeUtils.checkAxis(brank, axisb);
058
059                final boolean[] achoice = new boolean[arank];
060                final boolean[] bchoice = new boolean[brank];
061                Arrays.fill(achoice, true);
062                Arrays.fill(bchoice, true);
063                achoice[aaxis] = false; // flag which axes not to iterate over
064                bchoice[baxis] = false;
065
066                final boolean[] notachoice = new boolean[arank];
067                final boolean[] notbchoice = new boolean[brank];
068                notachoice[aaxis] = true; // flag which axes to iterate over
069                notbchoice[baxis] = true;
070
071                int drank = arank + brank - 2;
072                int[] dshape = new int[drank];
073                int d = 0;
074                for (int i = 0; i < arank; i++) {
075                        if (achoice[i])
076                                dshape[d++] = ashape[i];
077                }
078                for (int i = 0; i < brank; i++) {
079                        if (bchoice[i])
080                                dshape[d++] = bshape[i];
081                }
082                int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
083                @SuppressWarnings("deprecation")
084                Dataset data = DatasetFactory.zeros(dshape, dtype);
085
086                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
087                int l = 0;
088                final int[] apos = ita.getPos();
089                while (ita.hasNext()) {
090                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
091                        final int[] bpos = itb.getPos();
092                        while (itb.hasNext()) {
093                                SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice);
094                                SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice);
095                                double sum = 0.0;
096                                double com = 0.0;
097                                while (itaa.hasNext() && itba.hasNext()) {
098                                        final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com;
099                                        final double t = sum + y;
100                                        com = (t - sum) - y;
101                                        sum = t;
102                                }
103                                data.setObjectAbs(l++, sum);
104                        }
105                }
106
107                return data;
108        }
109
110        /**
111         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
112         * from the given axes in each dataset
113         * @param a
114         * @param b
115         * @param axisa axis dimensions in a to sum over (can be -ve)
116         * @param axisb axis dimensions in b to sum over (can be -ve)
117         * @return tensor dot product
118         */
119        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) {
120                if (axisa.length != axisb.length) {
121                        throw new IllegalArgumentException("Numbers of summing axes must be same");
122                }
123                final int[] ashape = a.getShapeRef();
124                final int[] bshape = b.getShapeRef();
125                final int arank = ashape.length;
126                final int brank = bshape.length;
127                final int[] aaxes = new int[axisa.length];
128                final int[] baxes = new int[axisa.length];
129                for (int i = 0; i < axisa.length; i++) {
130                        aaxes[i] = ShapeUtils.checkAxis(arank, axisa[i]);
131                        int n = ShapeUtils.checkAxis(brank, axisb[i]);
132                        baxes[i] = n;
133
134                        if (ashape[aaxes[i]] != bshape[n]) {
135                                throw new IllegalArgumentException("Summing axes do not have matching lengths");
136                        }
137                }
138
139                final boolean[] achoice = new boolean[arank];
140                final boolean[] bchoice = new boolean[brank];
141                Arrays.fill(achoice, true);
142                Arrays.fill(bchoice, true);
143                for (int i = 0; i < aaxes.length; i++) { // flag which axes to iterate over
144                        achoice[aaxes[i]] = false;
145                        bchoice[baxes[i]] = false;
146                }
147
148                int drank = arank + brank - 2*aaxes.length;
149                int[] dshape = new int[drank];
150                int d = 0;
151                for (int i = 0; i < arank; i++) {
152                        if (achoice[i])
153                                dshape[d++] = ashape[i];
154                }
155                for (int i = 0; i < brank; i++) {
156                        if (bchoice[i])
157                                dshape[d++] = bshape[i];
158                }
159                int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
160                @SuppressWarnings("deprecation")
161                Dataset data = DatasetFactory.zeros(dshape, dtype);
162
163                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
164                int l = 0;
165                final int[] apos = ita.getPos();
166                while (ita.hasNext()) {
167                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
168                        final int[] bpos = itb.getPos();
169                        while (itb.hasNext()) {
170                                double sum = 0.0;
171                                double com = 0.0;
172                                apos[aaxes[aaxes.length - 1]] = -1;
173                                bpos[baxes[aaxes.length - 1]] = -1;
174                                while (true) { // step through summing axes
175                                        int e = aaxes.length - 1;
176                                        for (; e >= 0; e--) {
177                                                int ai = aaxes[e];
178                                                int bi = baxes[e];
179
180                                                apos[ai]++;
181                                                bpos[bi]++;
182                                                if (apos[ai] == ashape[ai]) {
183                                                        apos[ai] = 0;
184                                                        bpos[bi] = 0;
185                                                } else
186                                                        break;
187                                        }
188                                        if (e == -1) break;
189                                        final double y = a.getDouble(apos) * b.getDouble(bpos) - com;
190                                        final double t = sum + y;
191                                        com = (t - sum) - y;
192                                        sum = t;
193                                }
194                                data.setObjectAbs(l++, sum);
195                        }
196                }
197
198                return data;
199        }
200
201        /**
202         * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over
203         * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset
204         * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used
205         * @param a
206         * @param b
207         * @return dot product
208         */
209        public static Dataset dotProduct(Dataset a, Dataset b) {
210                if (b.getRank() < 2)
211                        return tensorDotProduct(a, b, -1, 0);
212                return tensorDotProduct(a, b, -1, -2);
213        }
214
215        /**
216         * Calculate the outer product of two datasets
217         * @param a
218         * @param b
219         * @return outer product
220         */
221        public static Dataset outerProduct(Dataset a, Dataset b) {
222                int[] as = a.getShapeRef();
223                int[] bs = b.getShapeRef();
224                int rank = as.length + bs.length;
225                int[] shape = new int[rank];
226                for (int i = 0; i < as.length; i++) {
227                        shape[i] = as[i];
228                }
229                for (int i = 0; i < bs.length; i++) {
230                        shape[as.length + i] = bs[i];
231                }
232                int isa = a.getElementsPerItem();
233                int isb = b.getElementsPerItem();
234                if (isa != 1 || isb != 1) {
235                        throw new UnsupportedOperationException("Compound datasets not supported");
236                }
237                @SuppressWarnings("deprecation")
238                Dataset o = DatasetFactory.zeros(shape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
239
240                IndexIterator ita = a.getIterator();
241                IndexIterator itb = b.getIterator();
242                int j = 0;
243                while (ita.hasNext()) {
244                        double va = a.getElementDoubleAbs(ita.index);
245                        while (itb.hasNext()) {
246                                o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index));
247                        }
248                        itb.reset();
249                }
250                return o;
251        }
252
253        /**
254         * Calculate the cross product of two datasets. Datasets must be broadcastable and
255         * possess last dimensions of length 2 or 3
256         * @param a
257         * @param b
258         * @return cross product
259         */
260        public static Dataset crossProduct(Dataset a, Dataset b) {
261                return crossProduct(a, b, -1, -1, -1);
262        }
263
264        /**
265         * Calculate the cross product of two datasets. Datasets must be broadcastable and
266         * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate
267         * dimensions from the end of their shapes
268         * @param a
269         * @param b
270         * @param axisA dimension to be used a vector (must have length of 2 or 3)
271         * @param axisB dimension to be used a vector (must have length of 2 or 3)
272         * @param axisC dimension to assign as cross-product
273         * @return cross product
274         */
275        public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
276                final int rankA = a.getRank();
277                final int rankB = b.getRank();
278                if (rankA == 0 || rankB == 0) {
279                        throw new IllegalArgumentException("Datasets must have one or more dimensions");
280                }
281                axisA = a.checkAxis(axisA);
282                axisB = b.checkAxis(axisB);
283
284                final int[] shapeA = a.getShape();
285                final int[] shapeB = b.getShape();
286                int la = shapeA[axisA];
287                int lb = shapeB[axisB];
288                if (Math.min(la,  lb) < 2 || Math.max(la, lb) > 3) {
289                        throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3");
290                }
291
292                if (Math.max(la,  lb) == 2) {
293                        return crossProduct2D(a, b, axisA, axisB);
294                }
295
296                return crossProduct3D(a, b, axisA, axisB, axisC);
297        }
298
299        private static int[] removeAxisFromShape(int[] shape, int axis) {
300                int[] s = new int[shape.length - 1];
301                int i = 0;
302                int j = 0;
303                while (i < axis) {
304                        s[j++] = shape[i++];
305                }
306                i++;
307                while (i < shape.length) {
308                        s[j++] = shape[i++];
309                }
310                return s;
311        }
312
313        // assume axes is in increasing order
314        private static int[] removeAxesFromShape(int[] shape, int... axes) {
315                int n = axes.length;
316                int[] s = new int[shape.length - n];
317                int i = 0;
318                int j = 0;
319                for (int k = 0; k < n; k++) {
320                        int a = axes[k];
321                        while (i < a) {
322                                s[j++] = shape[i++];
323                        }
324                        i++;
325                }
326                while (i < shape.length) {
327                        s[j++] = shape[i++];
328                }
329                return s;
330        }
331
332        private static int[] addAxisToShape(int[] shape, int axis, int length) {
333                int[] s = new int[shape.length + 1];
334                int i = 0;
335                int j = 0;
336                while (i < axis) {
337                        s[j++] = shape[i++];
338                }
339                s[j++] = length;
340                while (i < shape.length) {
341                        s[j++] = shape[i++];
342                }
343                return s;
344        }
345
346        private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) {
347                // need to broadcast and omit given axes
348                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
349                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);
350
351                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);
352
353                int[] maxShape = fullShapes.get(0);
354                @SuppressWarnings("deprecation")
355                Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
356
357                PositionIterator ita = a.getPositionIterator(axisA);
358                PositionIterator itb = b.getPositionIterator(axisB);
359                IndexIterator itc = c.getIterator();
360
361                final int[] pa = ita.getPos();
362                final int[] pb = itb.getPos();
363                while (itc.hasNext()) {
364                        if (!ita.hasNext()) // TODO use broadcasting...
365                                ita.reset();
366                        if (!itb.hasNext())
367                                itb.reset();
368                        pa[axisA] = 0;
369                        pb[axisB] = 1;
370                        double cv = a.getDouble(pa) * b.getDouble(pb);
371                        pa[axisA] = 1;
372                        pb[axisB] = 0;
373                        cv -= a.getDouble(pa) * b.getDouble(pb);
374
375                        c.setObjectAbs(itc.index, cv);
376                }
377                return c;
378        }
379
380        private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
381                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
382                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);
383
384                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);
385
386                int[] maxShape = fullShapes.get(0);
387                int rankC = maxShape.length + 1;
388                axisC = ShapeUtils.checkAxis(rankC, axisC);
389                maxShape = addAxisToShape(maxShape, axisC, 3);
390                @SuppressWarnings("deprecation")
391                Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
392
393                PositionIterator ita = a.getPositionIterator(axisA);
394                PositionIterator itb = b.getPositionIterator(axisB);
395                PositionIterator itc = c.getPositionIterator(axisC);
396
397                final int[] pa = ita.getPos();
398                final int[] pb = itb.getPos();
399                final int[] pc = itc.getPos();
400                final int la = a.getShapeRef()[axisA];
401                final int lb = b.getShapeRef()[axisB];
402
403                if (la == 2) {
404                        while (itc.hasNext()) {
405                                if (!ita.hasNext()) // TODO use broadcasting...
406                                        ita.reset();
407                                if (!itb.hasNext())
408                                        itb.reset();
409                                double cv;
410                                pa[axisA] = 1;
411                                pb[axisB] = 2;
412                                cv = a.getDouble(pa) * b.getDouble(pb);
413                                pc[axisC] = 0;
414                                c.set(cv, pc);
415
416                                pa[axisA] = 0;
417                                pb[axisB] = 2;
418                                cv = -a.getDouble(pa) * b.getDouble(pb);
419                                pc[axisC] = 1;
420                                c.set(cv, pc);
421
422                                pa[axisA] = 0;
423                                pb[axisB] = 1;
424                                cv = a.getDouble(pa) * b.getDouble(pb);
425                                pa[axisA] = 1;
426                                pb[axisB] = 0;
427                                cv -= a.getDouble(pa) * b.getDouble(pb);
428                                pc[axisC] = 2;
429                                c.set(cv, pc);
430                        }
431                } else if (lb == 2) {
432                        while (itc.hasNext()) {
433                                if (!ita.hasNext()) // TODO use broadcasting...
434                                        ita.reset();
435                                if (!itb.hasNext())
436                                        itb.reset();
437                                double cv;
438                                pa[axisA] = 2;
439                                pb[axisB] = 1;
440                                cv = -a.getDouble(pa) * b.getDouble(pb);
441                                pc[axisC] = 0;
442                                c.set(cv, pc);
443
444                                pa[axisA] = 2;
445                                pb[axisB] = 0;
446                                cv = a.getDouble(pa) * b.getDouble(pb);
447                                pc[axisC] = 1;
448                                c.set(cv, pc);
449
450                                pa[axisA] = 0;
451                                pb[axisB] = 1;
452                                cv = a.getDouble(pa) * b.getDouble(pb);
453                                pa[axisA] = 1;
454                                pb[axisB] = 0;
455                                cv -= a.getDouble(pa) * b.getDouble(pb);
456                                pc[axisC] = 2;
457                                c.set(cv, pc);
458                        }
459                        
460                } else {
461                        while (itc.hasNext()) {
462                                if (!ita.hasNext()) // TODO use broadcasting...
463                                        ita.reset();
464                                if (!itb.hasNext())
465                                        itb.reset();
466                                double cv;
467                                pa[axisA] = 1;
468                                pb[axisB] = 2;
469                                cv = a.getDouble(pa) * b.getDouble(pb);
470                                pa[axisA] = 2;
471                                pb[axisB] = 1;
472                                cv -= a.getDouble(pa) * b.getDouble(pb);
473                                pc[axisC] = 0;
474                                c.set(cv, pc);
475
476                                pa[axisA] = 2;
477                                pb[axisB] = 0;
478                                cv = a.getDouble(pa) * b.getDouble(pb);
479                                pa[axisA] = 0;
480                                pb[axisB] = 2;
481                                cv -= a.getDouble(pa) * b.getDouble(pb);
482                                pc[axisC] = 1;
483                                c.set(cv, pc);
484
485                                pa[axisA] = 0;
486                                pb[axisB] = 1;
487                                cv = a.getDouble(pa) * b.getDouble(pb);
488                                pa[axisA] = 1;
489                                pb[axisB] = 0;
490                                cv -= a.getDouble(pa) * b.getDouble(pb);
491                                pc[axisC] = 2;
492                                c.set(cv, pc);
493                        }
494                }
495                return c;
496        }
497
498        /**
499         * Raise dataset to given power by matrix multiplication
500         * @param a
501         * @param n power
502         * @return a ** n
503         */
504        public static Dataset power(Dataset a, int n) {
505                if (n < 0) {
506                        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
507                        return createDataset(lud.getSolver().getInverse().power(-n));
508                }
509                Dataset p = createDataset(createRealMatrix(a).power(n));
510                if (!a.hasFloatingPointElements())
511                        return p.cast(a.getDType());
512                return p;
513        }
514
515        /**
516         * Create the Kronecker product as defined by 
517         * kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN]
518         * where kn = sn * in + jn for n = 0...N and s is shape of b
519         * @param a
520         * @param b
521         * @return Kronecker product of a and b
522         */
523        public static Dataset kroneckerProduct(Dataset a, Dataset b) {
524                if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
525                        throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported");
526                }
527                int ar = a.getRank();
528                int br = b.getRank();
529                int[] aShape;
530                int[] bShape;
531                aShape = a.getShapeRef();
532                bShape = b.getShapeRef();
533                int r = ar;
534                // pre-pad if ranks are not same
535                if (ar < br) {
536                        r = br;
537                        int[] shape = new int[br];
538                        int j = 0;
539                        for (int i = ar; i < br; i++) {
540                                shape[j++] = 1;
541                        }
542                        int i = 0;
543                        while (j < br) {
544                                shape[j++] = aShape[i++];
545                        }
546                        a = a.reshape(shape);
547                        aShape = shape;
548                } else if (ar > br) {
549                        int[] shape = new int[ar];
550                        int j = 0;
551                        for (int i = br; i < ar; i++) {
552                                shape[j++] = 1;
553                        }
554                        int i = 0;
555                        while (j < ar) {
556                                shape[j++] = bShape[i++];
557                        }
558                        b = b.reshape(shape);
559                        bShape = shape;
560                }
561
562                int[] nShape = new int[r];
563                for (int i = 0; i < r; i++) {
564                        nShape[i] = aShape[i] * bShape[i];
565                }
566                @SuppressWarnings("deprecation")
567                Dataset kron = DatasetFactory.zeros(nShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
568                IndexIterator ita = a.getIterator(true);
569                IndexIterator itb = b.getIterator(true);
570                int[] pa = ita.getPos();
571                int[] pb = itb.getPos();
572                int[] off = new int[1];
573                int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off);
574                if (kron.getDType() == Dataset.INT64) {
575                        while (ita.hasNext()) {
576                                long av = a.getElementLongAbs(ita.index);
577
578                                int ka = 0; 
579                                for (int i = 0; i < r; i++) {
580                                        ka += stride[i] * bShape[i] * pa[i];
581                                }
582                                itb.reset();
583                                while (itb.hasNext()) {
584                                        long bv = b.getElementLongAbs(itb.index);
585                                        int kb = ka;
586                                        for (int i = 0; i < r; i++) {
587                                                kb += stride[i] * pb[i];
588                                        }
589                                        kron.setObjectAbs(kb, av * bv);
590                                }
591                        }
592                } else {
593                        while (ita.hasNext()) {
594                                double av = a.getElementDoubleAbs(ita.index);
595
596                                int ka = 0; 
597                                for (int i = 0; i < r; i++) {
598                                        ka += stride[i] * bShape[i] * pa[i];
599                                }
600                                itb.reset();
601                                while (itb.hasNext()) {
602                                        double bv = b.getElementLongAbs(itb.index);
603                                        int kb = ka;
604                                        for (int i = 0; i < r; i++) {
605                                                kb += stride[i] * pb[i];
606                                        }
607                                        kron.setObjectAbs(kb, av * bv);
608                                }
609                        }
610                }
611
612                return kron;
613        }
614
615        /**
616         * Calculate trace of dataset - sum of values over 1st axis and 2nd axis
617         * @param a
618         * @return trace of dataset
619         */
620        public static Dataset trace(Dataset a) {
621                return trace(a, 0, 0, 1);
622        }
623
624        /**
625         * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset
626         * @param a
627         * @param offset
628         * @param axis1
629         * @param axis2
630         * @return trace of dataset
631         */
        public static Dataset trace(Dataset a, int offset, int axis1, int axis2) {
                int[] shape = a.getShapeRef();
                // normalise the (possibly negative) axes and order them so that axes[0] < axes[1];
                // NOTE(review): because of this sort, offset always refers to the lower/higher-numbered
                // axis rather than to axis1/axis2 (numpy.trace flips the sign when axis1 > axis2) - confirm intended
                int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) };
                Arrays.sort(axes);
                int is = a.getElementsPerItem();
                // output keeps a's item size and dtype; presumably its shape is a's shape without the two axes
                @SuppressWarnings("deprecation")
                Dataset trace = DatasetFactory.zeros(is, removeAxesFromShape(shape, axes), a.getDType());

                int am = axes[0];
                int mmax = shape[am];
                int an = axes[1];
                int nmax = shape[an];
                // iterate over the positions of the remaining axes; pos is shared with the iterator and its
                // am/an entries are overwritten below (assumes PositionIterator leaves those entries alone)
                PositionIterator it = new PositionIterator(shape, axes);
                int[] pos = it.getPos();
                int i = 0; // absolute index into trace, advanced once per output item
                int mmin;
                int nmin;
                // starting point of the diagonal in the (am, an) plane: (0, offset) or (-offset, 0)
                if (offset >= 0) {
                        mmin = 0;
                        nmin = offset;
                } else {
                        mmin = -offset;
                        nmin = 0;
                }
                if (is == 1) {
                        // single-element items: 64-bit integers are summed exactly as longs, all others as doubles
                        if (a.getDType() == Dataset.INT64) {
                                while (it.hasNext()) {
                                        int m = mmin;
                                        int n = nmin;
                                        long s = 0;
                                        // walk the diagonal until either axis is exhausted
                                        while (m < mmax && n < nmax) {
                                                pos[am] = m++;
                                                pos[an] = n++;
                                                s += a.getLong(pos);
                                        }
                                        trace.setObjectAbs(i++, s);
                                }
                        } else {
                                while (it.hasNext()) {
                                        int m = mmin;
                                        int n = nmin;
                                        double s = 0;
                                        // walk the diagonal until either axis is exhausted
                                        while (m < mmax && n < nmax) {
                                                pos[am] = m++;
                                                pos[an] = n++;
                                                s += a.getDouble(pos);
                                        }
                                        trace.setObjectAbs(i++, s);
                                }
                        }
                } else {
                        // multi-element items: the cast assumes a is a compound dataset; each element of the
                        // item is summed separately into s
                        AbstractCompoundDataset ca = (AbstractCompoundDataset) a;
                        if (ca instanceof CompoundLongDataset) {
                                long[] t = new long[is]; // item read at the current diagonal position
                                long[] s = new long[is]; // running per-element sums
                                while (it.hasNext()) {
                                        int m = mmin;
                                        int n = nmin;
                                        Arrays.fill(s, 0);
                                        while (m < mmax && n < nmax) {
                                                pos[am] = m++;
                                                pos[an] = n++;
                                                ((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t);
                                                for (int k = 0; k < is; k++) {
                                                        s[k] += t[k];
                                                }
                                        }
                                        trace.setObjectAbs(i++, s);
                                }
                        } else {
                                double[] t = new double[is]; // item read at the current diagonal position
                                double[] s = new double[is]; // running per-element sums
                                while (it.hasNext()) {
                                        int m = mmin;
                                        int n = nmin;
                                        Arrays.fill(s, 0);
                                        while (m < mmax && n < nmax) {
                                                pos[am] = m++;
                                                pos[an] = n++;
                                                ca.getDoubleArray(t, pos);
                                                for (int k = 0; k < is; k++) {
                                                        s[k] += t[k];
                                                }
                                        }
                                        trace.setObjectAbs(i++, s);
                                }
                        }
                }

                return trace;
        }
723
        /**
         * Order value for norm, as used by {@link #norm(Dataset, NormOrder)}
         */
        public enum NormOrder {
                /**
                 * 2-norm for vectors and Frobenius for matrices
                 */
                DEFAULT,
                /**
                 * Frobenius, i.e. square root of the sum of squared absolute values of all elements
                 * (not allowed for vectors)
                 */
                FROBENIUS,
                /**
                 * Zero-order, i.e. the number of non-zero elements (not allowed for matrices)
                 */
                ZERO,
                /**
                 * Positive infinity: maximum absolute value for vectors; maximum of the sums of
                 * absolute values along axis 1 (row sums) for matrices
                 */
                POS_INFINITY,
                /**
                 * Negative infinity: minimum absolute value for vectors; minimum of the sums of
                 * absolute values along axis 1 (row sums) for matrices
                 */
                NEG_INFINITY;
        }
749
750        /**
751         * @param a
752         * @return norm of dataset
753         */
754        public static double norm(Dataset a) {
755                return norm(a, NormOrder.DEFAULT);
756        }
757
758        /**
759         * @param a
760         * @param order
761         * @return norm of dataset
762         */
763        public static double norm(Dataset a, NormOrder order) {
764                int r = a.getRank();
765                if (r == 1) {
766                        return vectorNorm(a, order);
767                } else if (r == 2) {
768                        return matrixNorm(a, order);
769                }
770                throw new IllegalArgumentException("Rank of dataset must be one or two");
771        }
772
773        private static double vectorNorm(Dataset a, NormOrder order) {
774                double n;
775                IndexIterator it;
776                switch (order) {
777                case FROBENIUS:
778                        throw new IllegalArgumentException("Not allowed for vectors");
779                case NEG_INFINITY:
780                case POS_INFINITY:
781                        it = a.getIterator();
782                        if (order == NormOrder.POS_INFINITY) {
783                                n = Double.NEGATIVE_INFINITY;
784                                if (a.isComplex()) {
785                                        while (it.hasNext()) {
786                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
787                                                n = Math.max(n, v);
788                                        }
789                                } else {
790                                        while (it.hasNext()) {
791                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
792                                                n = Math.max(n, v);
793                                        }
794                                }
795                        } else {
796                                n = Double.POSITIVE_INFINITY;
797                                if (a.isComplex()) {
798                                        while (it.hasNext()) {
799                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
800                                                n = Math.min(n, v);
801                                        }
802                                } else {
803                                        while (it.hasNext()) {
804                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
805                                                n = Math.min(n, v);
806                                        }
807                                }
808                        }
809                        break;
810                case ZERO:
811                        it = a.getIterator();
812                        n = 0;
813                        if (a.isComplex()) {
814                                while (it.hasNext()) {
815                                        if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO))
816                                                n++;
817                                }
818                        } else {
819                                while (it.hasNext()) {
820                                        if (a.getElementBooleanAbs(it.index))
821                                                n++;
822                                }
823                        }
824                        
825                        break;
826                default:
827                        n = vectorNorm(a, 2);
828                        break;
829                }
830                return n;
831        }
832
833        private static double matrixNorm(Dataset a, NormOrder order) {
834                double n;
835                IndexIterator it;
836                switch (order) {
837                case NEG_INFINITY:
838                case POS_INFINITY:
839                        n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY);
840                        break;
841                case ZERO:
842                        throw new IllegalArgumentException("Not allowed for matrices");
843                default:
844                case FROBENIUS:
845                        it = a.getIterator();
846                        n = 0;
847                        if (a.isComplex()) {
848                                while (it.hasNext()) {
849                                        double v = ((Complex) a.getObjectAbs(it.index)).abs();
850                                        n += v*v;
851                                }
852                        } else {
853                                while (it.hasNext()) {
854                                        double v = a.getElementDoubleAbs(it.index);
855                                        n += v*v;
856                                }
857                        }
858                        n = Math.sqrt(n);
859                        break;
860                }
861                return n;
862        }
863
864        /**
865         * @param a
866         * @param p
867         * @return p-norm of dataset
868         */
869        public static double norm(Dataset a, final double p) {
870                if (p == 0) {
871                        return norm(a, NormOrder.ZERO);
872                }
873                int r = a.getRank();
874                if (r == 1) {
875                        return vectorNorm(a, p);
876                } else if (r == 2) {
877                        return matrixNorm(a, p);
878                }
879                throw new IllegalArgumentException("Rank of dataset must be one or two");
880        }
881
882        private static double vectorNorm(Dataset a, final double p) {
883                IndexIterator it = a.getIterator();
884                double n = 0;
885                if (a.isComplex()) {
886                        while (it.hasNext()) {
887                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
888                                if (p == 2) {
889                                        v *= v;
890                                } else if (p != 1) {
891                                        v = Math.pow(v, p);
892                                }
893                                n += v;
894                        }
895                } else {
896                        while (it.hasNext()) {
897                                double v = a.getElementDoubleAbs(it.index);
898                                if (p == 1) {
899                                        v = Math.abs(v);
900                                } else if (p == 2) {
901                                        v *= v;
902                                } else {
903                                        v = Math.pow(Math.abs(v), p);
904                                }
905                                n += v;
906                        }
907                }
908                return Math.pow(n, 1./p);
909        }
910
911        private static double matrixNorm(Dataset a, final double p) {
912                double n;
913                if (Math.abs(p) == 1) {
914                        n = maxMinMatrixNorm(a, 0, p > 0);
915                } else if (Math.abs(p) == 2) {
916                        double[] s = calcSingularValues(a);
917                        n = p > 0 ? s[0] : s[s.length - 1];
918                } else {
919                        throw new IllegalArgumentException("Order not allowed");
920                }
921
922                return n;
923        }
924
925        private static double maxMinMatrixNorm(Dataset a, int d, boolean max) {
926                double n;
927                IndexIterator it;
928                int[] pos;
929                int l;
930                it = a.getPositionIterator(d);
931                pos = it.getPos();
932                l = a.getShapeRef()[d];
933                if (max) {
934                        n = Double.NEGATIVE_INFINITY;
935                        if (a.isComplex()) {
936                                while (it.hasNext()) {
937                                        double v = ((Complex) a.getObject(pos)).abs();
938                                        for (int i = 1; i < l; i++) {
939                                                pos[d] = i;
940                                                v += ((Complex) a.getObject(pos)).abs();
941                                        }
942                                        pos[d] = 0;
943                                        n = Math.max(n, v);
944                                }
945                        } else {
946                                while (it.hasNext()) {
947                                        double v = Math.abs(a.getDouble(pos));
948                                        for (int i = 1; i < l; i++) {
949                                                pos[d] = i;
950                                                v += Math.abs(a.getDouble(pos));
951                                        }
952                                        pos[d] = 0;
953                                        n = Math.max(n, v);
954                                }
955                        }
956                } else {
957                        n = Double.POSITIVE_INFINITY;
958                        if (a.isComplex()) {
959                                while (it.hasNext()) {
960                                        double v = ((Complex) a.getObject(pos)).abs();
961                                        for (int i = 1; i < l; i++) {
962                                                pos[d] = i;
963                                                v += ((Complex) a.getObject(pos)).abs();
964                                        }
965                                        pos[d] = 0;
966                                        n = Math.min(n, v);
967                                }
968                        } else {
969                                while (it.hasNext()) {
970                                        double v = Math.abs(a.getDouble(pos));
971                                        for (int i = 1; i < l; i++) {
972                                                pos[d] = i;
973                                                v += Math.abs(a.getDouble(pos));
974                                        }
975                                        pos[d] = 0;
976                                        n = Math.min(n, v);
977                                }
978                        }
979                }
980                return n;
981        }
982
983        /**
984         * @param a
985         * @return array of singular values
986         */
987        public static double[] calcSingularValues(Dataset a) {
988                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
989                return svd.getSingularValues();
990        }
991
992
993        /**
994         * Calculate singular value decomposition A = U S V^T
995         * @param a
996         * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix
997         */
998        public static Dataset[] calcSingularValueDecomposition(Dataset a) {
999                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1000                return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()),
1001                                createDataset(svd.getV())};
1002        }
1003
1004        /**
1005         * Calculate (Moore-Penrose) pseudo-inverse
1006         * @param a
1007         * @return pseudo-inverse
1008         */
1009        public static Dataset calcPseudoInverse(Dataset a) {
1010                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1011                return createDataset(svd.getSolver().getInverse());
1012        }
1013
1014        /**
1015         * Calculate matrix rank by singular value decomposition method
1016         * @param a
1017         * @return effective numerical rank of matrix
1018         */
1019        public static int calcMatrixRank(Dataset a) {
1020                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1021                return svd.getRank();
1022        }
1023
1024        /**
1025         * Calculate condition number of matrix by singular value decomposition method
1026         * @param a
1027         * @return condition number
1028         */
1029        public static double calcConditionNumber(Dataset a) {
1030                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1031                return svd.getConditionNumber();
1032        }
1033
1034        /**
1035         * @param a
1036         * @return determinant of dataset
1037         */
1038        public static double calcDeterminant(Dataset a) {
1039                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1040                return evd.getDeterminant();
1041        }
1042
1043        /**
1044         * @param a
1045         * @return dataset of eigenvalues (can be double or complex double)
1046         */
1047        public static Dataset calcEigenvalues(Dataset a) {
1048                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1049                double[] rev = evd.getRealEigenvalues();
1050
1051                if (evd.hasComplexEigenvalues()) {
1052                        double[] iev = evd.getImagEigenvalues();
1053                        return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1054                }
1055                return DatasetFactory.createFromObject(rev);
1056        }
1057
1058        /**
1059         * Calculate eigen-decomposition A = V D V^T
1060         * @param a
1061         * @return array of D eigenvalues (can be double or complex double) and V eigenvectors
1062         */
1063        public static Dataset[] calcEigenDecomposition(Dataset a) {
1064                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1065                Dataset[] results = new Dataset[2];
1066
1067                double[] rev = evd.getRealEigenvalues();
1068                if (evd.hasComplexEigenvalues()) {
1069                        double[] iev = evd.getImagEigenvalues();
1070                        results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1071                } else {
1072                        results[0] = DatasetFactory.createFromObject(rev);
1073                }
1074                results[1] = createDataset(evd.getV());
1075                return results;
1076        }
1077
1078        /**
1079         * Calculate QR decomposition A = Q R
1080         * @param a
1081         * @return array of Q and R
1082         */
1083        public static Dataset[] calcQRDecomposition(Dataset a) {
1084                QRDecomposition qrd = new QRDecomposition(createRealMatrix(a));
1085                return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())};
1086        }
1087
1088        /**
1089         * Calculate LU decomposition A = P^-1 L U
1090         * @param a
1091         * @return array of L, U and P
1092         */
1093        public static Dataset[] calcLUDecomposition(Dataset a) {
1094                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1095                return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()),
1096                                createDataset(lud.getP())};
1097        }
1098
1099        /**
1100         * Calculate inverse of square dataset
1101         * @param a
1102         * @return inverse
1103         */
1104        public static Dataset calcInverse(Dataset a) {
1105                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1106                return createDataset(lud.getSolver().getInverse());
1107        }
1108
1109        /**
1110         * Solve linear matrix equation A x = v
1111         * @param a
1112         * @param v
1113         * @return x
1114         */
1115        public static Dataset solve(Dataset a, Dataset v) {
1116                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1117                if (v.getRank() == 1) {
1118                        RealVector x = createRealVector(v);
1119                        return createDataset(lud.getSolver().solve(x));
1120                }
1121                RealMatrix x = createRealMatrix(v);
1122                return createDataset(lud.getSolver().solve(x));
1123        }
1124
1125        
1126        /**
1127         * Solve least squares matrix equation A x = v by SVD
1128         * @param a
1129         * @param v
1130         * @return x
1131         */
1132        public static Dataset solveSVD(Dataset a, Dataset v) {
1133                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1134                if (v.getRank() == 1) {
1135                        RealVector x = createRealVector(v);
1136                        return createDataset(svd.getSolver().solve(x));
1137                }
1138                RealMatrix x = createRealMatrix(v);
1139                return createDataset(svd.getSolver().solve(x));
1140        }
1141        
1142        /**
1143         * Calculate Cholesky decomposition A = L L^T
1144         * @param a
1145         * @return L
1146         */
1147        public static Dataset calcCholeskyDecomposition(Dataset a) {
1148                CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a));
1149                return createDataset(cd.getL());
1150        }
1151
1152        /**
1153         * Calculation A x = v by conjugate gradient method with the stopping criterion being
1154         * that the estimated residual r = v - A x satisfies ||r|| < ||v|| with maximum of 100 iterations
1155         * @param a
1156         * @param v
1157         * @return solution of A^-1 v by conjugate gradient method
1158         */
1159        public static Dataset calcConjugateGradient(Dataset a, Dataset v) {
1160                return calcConjugateGradient(a, v, 100, 1);
1161        }
1162
1163        /**
1164         * Calculation A x = v by conjugate gradient method with the stopping criterion being
1165         * that the estimated residual r = v - A x satisfies ||r|| < delta ||v||
1166         * @param a
1167         * @param v
1168         * @param maxIterations
1169         * @param delta parameter used by stopping criterion
1170         * @return solution of A^-1 v by conjugate gradient method
1171         */
1172        public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
1173                ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
1174                return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
1175        }
1176
1177        private static RealMatrix createRealMatrix(Dataset a) {
1178                if (a.getRank() != 2) {
1179                        throw new IllegalArgumentException("Dataset must be rank 2");
1180                }
1181                int[] shape = a.getShapeRef();
1182                IndexIterator it = a.getIterator(true);
1183                int[] pos = it.getPos();
1184                RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]);
1185                while (it.hasNext()) {
1186                        m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index));
1187                }
1188                return m;
1189        }
1190
1191        private static RealVector createRealVector(Dataset a) {
1192                if (a.getRank() != 1) {
1193                        throw new IllegalArgumentException("Dataset must be rank 1");
1194                }
1195                int size = a.getSize();
1196                IndexIterator it = a.getIterator(true);
1197                int[] pos = it.getPos();
1198                RealVector m = new ArrayRealVector(size);
1199                while (it.hasNext()) {
1200                        m.setEntry(pos[0], a.getElementDoubleAbs(it.index));
1201                }
1202                return m;
1203        }
1204
1205        private static Dataset createDataset(RealVector v) {
1206                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension());
1207                int size = r.getSize();
1208                if (v instanceof ArrayRealVector) {
1209                        double[] data = ((ArrayRealVector) v).getDataRef();
1210                        for (int i = 0; i < size; i++) {
1211                                r.setAbs(i, data[i]);
1212                        }
1213                } else {
1214                        for (int i = 0; i < size; i++) {
1215                                r.setAbs(i, v.getEntry(i));
1216                        }
1217                }
1218                return r;
1219        }
1220
1221        private static Dataset createDataset(RealMatrix m) {
1222                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension());
1223                if (m instanceof Array2DRowRealMatrix) {
1224                        double[][] data = ((Array2DRowRealMatrix) m).getDataRef();
1225                        IndexIterator it = r.getIterator(true);
1226                        int[] pos = it.getPos();
1227                        while (it.hasNext()) {
1228                                r.setAbs(it.index, data[pos[0]][pos[1]]);
1229                        }
1230                } else {
1231                        IndexIterator it = r.getIterator(true);
1232                        int[] pos = it.getPos();
1233                        while (it.hasNext()) {
1234                                r.setAbs(it.index, m.getEntry(pos[0], pos[1]));
1235                        }
1236                }
1237                return r;
1238        }
1239}