001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2017 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *    Tom Schoonjans - min and max methods
012 *******************************************************************************/
013
014package org.eclipse.january.dataset;
015
016import java.util.ArrayList;
017import java.util.Arrays;
018import java.util.List;
019
020import org.eclipse.january.DatasetException;
021import org.slf4j.Logger;
022import org.slf4j.LoggerFactory;
023
024/**
025 * Mathematics class for lazy datasets
026 */
027public final class LazyMaths {
028
029        private static final String DUPLICATE_AXIS_ERROR = "Axis arguments must be unique";
030        private static final String TOO_MANY_AXES_ERROR = "Number of axes cannot be greater than the rank";
031
032        private LazyMaths() {
033
034        }
035
036        /**
037         * Setup the logging facilities
038         */
039        protected static final Logger logger = LoggerFactory.getLogger(LazyMaths.class);
040
041        // TODO Uncomment this next line when minimum JDK is set to 1.8
042        // @FunctionalInterface
043        private static interface IMathOperation {
044                void execute(IDataset a, IDataset b, Dataset c);
045        }
046
047        private enum MathOperation implements IMathOperation {
048                // TODO use lambdas here when moving to Java 8
049                MAX(new IMathOperation() {
050                        @Override
051                        public void execute(IDataset a, IDataset b, Dataset c) {
052                                Maths.maximum(a, b, c);
053                        }
054                }, "maximum"),
055                MIN(new IMathOperation() {
056                        @Override
057                        public void execute(IDataset a, IDataset b, Dataset c) {
058                                Maths.minimum(a, b, c);
059                        }
060                }, "minimum");
061
062                private final IMathOperation operation;
063                private final String operationName;
064
065                private MathOperation(IMathOperation operation, String operationName) {
066                        this.operation = operation;
067                        this.operationName = operationName;
068                }
069                
070                @Override
071                public void execute(IDataset a, IDataset b, Dataset c) {
072                        operation.execute(a, b, c);
073                }
074
075                /**
076                 * @return the operationName
077                 */
078                public String getOperationName() {
079                        return operationName;
080                }
081
082        }
083
084        private static int[] requireSortedAxes(final ILazyDataset data, int[] axes) {
085                int rank = data.getRank();
086                if (axes == null || axes.length == 0) { // take to mean use all axes
087                        axes = new int[rank];
088                        for (int i = 0; i < rank; i++) {
089                                axes[i] = i;
090                        }
091                } else {
092                        Arrays.sort(axes);
093                        if (rank < axes.length) {
094                                logger.error(TOO_MANY_AXES_ERROR);
095                                throw new IllegalArgumentException(TOO_MANY_AXES_ERROR);
096                        }
097                                
098                        for (int axisIndex = 0 ; axisIndex < axes.length ; axisIndex++) {
099                                if (axes.length > 1 && axisIndex > 0 && axes[axisIndex] == axes[axisIndex-1]) {
100                                        logger.error(DUPLICATE_AXIS_ERROR);
101                                        throw new IllegalArgumentException(DUPLICATE_AXIS_ERROR);
102                                }
103                                axes[axisIndex] = ShapeUtils.checkAxis(rank, axes[axisIndex]);
104                        }
105                }
106                return axes;
107        }
108
109        private static Dataset maxmin(final ILazyDataset data, MathOperation operation, int...axes) throws DatasetException {
110                axes = requireSortedAxes(data, axes);
111                
112                // we will be working here with the "ignoreAxes" instead to improve performance dramatically
113                int[] ignoreAxes = new int[data.getRank()-axes.length];
114                
115                int k = 0;
116                for (int i = 0 ; i < data.getRank() ; i++) {
117                        if (Arrays.binarySearch(axes, i) < 0)
118                                ignoreAxes[k++] = i;
119                }
120                
121                final int[] oldShape = data.getShape();
122
123                SliceND sa = new SliceND(oldShape);
124                SliceNDIterator it = new SliceNDIterator(sa, ignoreAxes);
125                Dataset result = null;
126                
127                while (it.hasNext()) {
128                        SliceND currentSlice = it.getCurrentSlice();
129                        IDataset slice = data.getSlice(currentSlice);
130                        if (result == null)
131                                result = DatasetUtils.convertToDataset(slice);
132                        else
133                                operation.execute(result, slice, result);
134                }
135                if (result != null) {
136                        result.setName(operation.getOperationName());
137                        result.squeeze();
138                }
139                return result;
140        }
141
142        /**
143         * @param data
144         * @param axes (can be negative). If null or empty then use all axes
145         * @return maximum along axes in lazy dataset
146         * @throws DatasetException
147         * @since 2.1
148         */
149        public static Dataset max(final ILazyDataset data, int... axes) throws DatasetException {
150                if (data instanceof Dataset) {
151                        Dataset tmp = (Dataset) data;
152                        axes = requireSortedAxes(data, axes);
153                        for (int i = axes.length - 1; i >= 0; i--) {
154                                tmp = tmp.max(axes[i]);
155                        }
156                        
157                        return tmp;
158                }
159                return maxmin(data, MathOperation.MAX, axes);
160        }
161
162        /**
163         * @param data
164         * @param axes (can be negative). If null or empty then use all axes
165         * @return minimum along axes in lazy dataset
166         * @throws DatasetException
167         * @since 2.1
168         */
169        public static Dataset min(final ILazyDataset data, int... axes) throws DatasetException {
170                if (data instanceof Dataset) {
171                        Dataset tmp = (Dataset) data;
172                        axes = requireSortedAxes(data, axes);
173                        for (int i = axes.length - 1; i >= 0; i--) {
174                                tmp = tmp.min(axes[i]);
175                        }
176                        
177                        return tmp;
178                }
179                return maxmin(data, MathOperation.MIN, axes);
180        }
181
182        /**
183         * @param data
184         * @param axis (can be negative)
185         * @return sum along axis in lazy dataset
186         * @throws DatasetException 
187         */
188        public static Dataset sum(final ILazyDataset data, int axis) throws DatasetException {
189                if (data instanceof Dataset)
190                        return ((Dataset) data).sum(axis);
191                int[][] sliceInfo = new int[3][];
192                int[] shape = data.getShape();
193                final Dataset result = prepareDataset(axis, shape, sliceInfo);
194
195                final int[] start = sliceInfo[0];
196                final int[] stop = sliceInfo[1];
197                final int[] step = sliceInfo[2];
198                final int length = shape[axis];
199
200                for (int i = 0; i < length; i++) {
201                        start[axis] = i;
202                        stop[axis] = i + 1;
203                        result.iadd(data.getSlice(start, stop, step));
204                }
205
206                result.setShape(ShapeUtils.squeezeShape(shape, axis));
207                return result;
208        }
209
210        /**
211         * @param data
212         * @param ignoreAxes axes to ignore
213         * @return sum when given axes are ignored in lazy dataset
214         * @throws DatasetException 
215         * @since 2.0
216         */
217        public static Dataset sum(final ILazyDataset data, int... ignoreAxes) throws DatasetException {
218                return sum(data, true, ignoreAxes);
219        }
220        
221        /**
222         * @param data
223         * @param ignore if true, ignore the provided axes, otherwise use only the provided axes 
224         * @param axes axes to ignore or accept, depending on the preceding flag
225         * @return sum
226         * @throws DatasetException 
227         * @since 2.0
228         */
229        public static Dataset sum(final ILazyDataset data, boolean ignore, int... axes) throws DatasetException {
230                Arrays.sort(axes); // ensure they are properly sorted
231        
232                ILazyDataset rv = data;
233                
234                if (ignore) {
235                        List<Integer> goodAxes = new ArrayList<>();
236                        for (int i = 0 ; i < data.getRank() ; i++) {
237                                boolean found = false;
238                                for (int j = 0 ; j < axes.length ; j++) {
239                                        if (i == axes[j]) {
240                                                found = true;
241                                                break;
242                                        }
243                                }
244                                if (!found)             
245                                        goodAxes.add(i);
246                        }
247
248                        for (int i = 0 ; i < goodAxes.size() ; i++) {
249                                rv = sum(rv, goodAxes.get(i) - i);
250                        }
251                } else {
252                        for (int i = 0 ; i < axes.length ; i++) {
253                                rv = sum(rv, axes[i] - i);
254                        }
255                }
256                return DatasetUtils.sliceAndConvertLazyDataset(rv);
257        }
258        
259        /**
260         * @param data
261         * @param axis (can be negative)
262         * @return product along axis in lazy dataset
263         * @throws DatasetException 
264         */
265        public static Dataset product(final ILazyDataset data, int axis) throws DatasetException {
266                int[][] sliceInfo = new int[3][];
267                int[] shape = data.getShape();
268                final Dataset result = prepareDataset(axis, shape, sliceInfo);
269                result.fill(1);
270
271                final int[] start = sliceInfo[0];
272                final int[] stop = sliceInfo[1];
273                final int[] step = sliceInfo[2];
274                final int length = shape[axis];
275
276                for (int i = 0; i < length; i++) {
277                        start[axis] = i;
278                        stop[axis] = i + 1;
279                        result.imultiply(data.getSlice(start, stop, step));
280                }
281
282                result.setShape(ShapeUtils.squeezeShape(shape, axis));
283                return result;
284        }
285
286        /**
287         * @param start
288         * @param stop inclusive
289         * @param data
290         * @param ignoreAxes
291         * @return mean when given axes are ignored in lazy dataset
292         * @throws DatasetException 
293         */
294        public static Dataset mean(int start, int stop, ILazyDataset data, int... ignoreAxes) throws DatasetException {
295                int[] shape = data.getShape();
296                PositionIterator iter = new PositionIterator(shape, ignoreAxes);
297                int[] pos = iter.getPos();
298                boolean[] omit = iter.getOmit();
299
300                int rank = shape.length;
301                int[] st = new int[rank];
302                Arrays.fill(st, 1);
303                int[] end = new int[rank];
304
305                RunningAverage av = null;
306                int c = 0;
307                while (iter.hasNext() && c < stop) {
308                        if (c++ < start) continue;
309                        for (int i = 0; i < rank; i++) {
310                                end[i] = omit[i] ? shape[i] : pos[i] + 1;
311                        }
312                        IDataset ds = data.getSlice(pos, end, st);
313                        if (av == null) {
314                                av = new RunningAverage(ds);
315                        } else {
316                                av.update(ds);
317                        }
318                }
319
320                return  av != null ? av.getCurrentAverage().squeeze() : null;
321        }
322        
323        public static Dataset mean(ILazyDataset data, int... ignoreAxes) throws DatasetException {
324                return mean(0, Integer.MAX_VALUE -1 , data, ignoreAxes);
325        }
326
327        private static Dataset prepareDataset(int axis, int[] shape, int[][] sliceInfo) {
328                int rank = shape.length;
329                axis = ShapeUtils.checkAxis(rank, axis);
330
331                sliceInfo[0] = new int[rank];
332                sliceInfo[1] = shape.clone();
333                sliceInfo[2] = new int[rank];
334                Arrays.fill(sliceInfo[2], 1);
335
336                final int[] nshape = shape.clone();
337                nshape[axis] = 1;
338
339                return DatasetFactory.zeros(nshape);
340        }
341}