001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.ArrayList;
013import java.util.Arrays;
014import java.util.List;
015
016public final class BroadcastUtils {
017
018        /**
019         * Calculate shapes for broadcasting
020         * @param oldShape
021         * @param size
022         * @param newShape
023         * @return broadcasted shape and full new shape or null if it cannot be done
024         */
025        public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
026                if (newShape == null)
027                        return null;
028        
029                int brank = newShape.length;
030                if (brank == 0) {
031                        if (size == 1)
032                                return new int[][] {oldShape, newShape};
033                        return null;
034                }
035        
036                if (Arrays.equals(oldShape, newShape))
037                        return new int[][] {oldShape, newShape};
038        
039                int offset = brank - oldShape.length;
040                if (offset < 0) { // when new shape is incomplete
041                        newShape = padShape(newShape, -offset);
042                        offset = 0;
043                }
044        
045                int[] bshape;
046                if (offset > 0) { // new shape has extra dimensions
047                        bshape = padShape(oldShape, offset);
048                } else {
049                        bshape = oldShape;
050                }
051        
052                for (int i = 0; i < brank; i++) {
053                        if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
054                                return null;
055                        }
056                }
057        
058                return new int[][] {bshape, newShape};
059        }
060
061        /**
062         * Pad shape by prefixing with ones
063         * @param shape
064         * @param padding
065         * @return new shape or old shape if padding is zero
066         */
067        public static int[] padShape(final int[] shape, final int padding) {
068                if (padding < 0)
069                        throw new IllegalArgumentException("Padding must be zero or greater");
070        
071                if (padding == 0)
072                        return shape;
073        
074                final int[] nshape = new int[shape.length + padding];
075                Arrays.fill(nshape, 1);
076                System.arraycopy(shape, 0, nshape, padding, shape.length);
077                return nshape;
078        }
079
080        /**
081         * Take in shapes and broadcast them to same rank
082         * @param shapes
083         * @return list of broadcasted shapes plus the first entry is the maximum shape
084         */
085        public static List<int[]> broadcastShapes(int[]... shapes) {
086                int maxRank = -1;
087                for (int[] s : shapes) {
088                        if (s == null)
089                                continue;
090        
091                        int r = s.length;
092                        if (r > maxRank) {
093                                maxRank = r;
094                        }
095                }
096        
097                List<int[]> newShapes = new ArrayList<int[]>();
098                for (int[] s : shapes) {
099                        if (s == null)
100                                continue;
101                        newShapes.add(padShape(s, maxRank - s.length));
102                }
103        
104                int[] maxShape = new int[maxRank];
105                for (int i = 0; i < maxRank; i++) {
106                        int m = -1;
107                        for (int[] s : newShapes) {
108                                int l = s[i];
109                                if (l > m) {
110                                        if (m > 1) {
111                                                throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
112                                        }
113                                        m = l;
114                                }
115                        }
116                        maxShape[i] = m;
117                }
118
119                checkShapes(maxShape, newShapes);
120                newShapes.add(0, maxShape);
121                return newShapes;
122        }
123
124        /**
125         * Take in shapes and broadcast them to maximum shape
126         * @param maxShape
127         * @param shapes
128         * @return list of broadcasted shapes
129         */
130        public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
131                int maxRank = maxShape.length;
132                for (int[] s : shapes) {
133                        if (s == null)
134                                continue;
135        
136                        int r = s.length;
137                        if (r > maxRank) {
138                                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
139                        }
140                }
141        
142                List<int[]> newShapes = new ArrayList<int[]>();
143                for (int[] s : shapes) {
144                        if (s == null)
145                                continue;
146                        newShapes.add(padShape(s, maxRank - s.length));
147                }
148
149                checkShapes(maxShape, newShapes);
150                return newShapes;
151        }
152
153        private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
154                for (int i = 0; i < maxShape.length; i++) {
155                        int m = maxShape[i];
156                        for (int[] s : newShapes) {
157                                int l = s[i];
158                                if (l != 1 && l != m) {
159                                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
160                                }
161                        }
162                }
163        }
164
165        static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
166                final Class<? extends Dataset> rc;
167                final int ar = a.getRank();
168                final int br = b.getRank();
169                Class<? extends Dataset> tc = InterfaceUtils.getBestInterface(a.getClass(), b.getClass());
170                if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point 
171                        if (ar == 0) {
172                                rc = a.hasFloatingPointElements() ? tc : b.getClass();
173                        } else {
174                                rc = b.hasFloatingPointElements() ? tc : a.getClass();
175                        }
176                } else {
177                        rc = tc;
178                }
179                final int ia = a.getElementsPerItem();
180                final int ib = b.getElementsPerItem();
181        
182                return DatasetFactory.zeros(ia > ib ? ia : ib, rc, shape);
183        }
184
185        /**
186         * Check if dataset item sizes are compatible
187         * <p>
188         * Dataset a is considered compatible with the output dataset if any of the
189         * conditions are true:
190         * <ul>
191         * <li>o is undefined</li>
192         * <li>a has item size equal to o's</li>
193         * <li>a has item size equal to 1</li>
194         * <li>o has item size equal to 1</li>
195         * </ul>
196         * @param a input dataset a
197         * @param o output dataset (can be null)
198         */
199        static void checkItemSize(Dataset a, Dataset o) {
200                final int isa = a.getElementsPerItem();
201                if (o != null) {
202                        final int iso = o.getElementsPerItem();
203                        if (isa != iso && isa != 1 && iso != 1) {
204                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
205                        }
206                }
207        }
208
209        /**
210         * Check if dataset item sizes are compatible
211         * <p>
212         * Dataset a is considered compatible with the output dataset if any of the
213         * conditions are true:
214         * <ul>
215         * <li>a has item size equal to b's</li>
216         * <li>a has item size equal to 1</li>
217         * <li>b has item size equal to 1</li>
218         * <li>a or b are single-valued</li>
219         * </ul>
220         * and, o is undefined, or any of the following are true:
221         * <ul>
222         * <li>o has item size equal to maximum of a and b's</li>
223         * <li>o has item size equal to 1</li>
224         * <li>a and b have item sizes of 1</li>
225         * </ul>
226         * @param a input dataset a
227         * @param b input dataset b
228         * @param o output dataset
229         */
230        static void checkItemSize(Dataset a, Dataset b, Dataset o) {
231                final int isa = a.getElementsPerItem();
232                final int isb = b.getElementsPerItem();
233                if (isa != isb && isa != 1 && isb != 1) {
234                        // exempt single-value dataset case too
235                        if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) {
236                                throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
237                        }
238                }
239                if (o != null && o.getDType() != Dataset.BOOL) {
240                        final int ism = Math.max(isa, isb);
241                        final int iso = o.getElementsPerItem();
242                        if (iso != ism && iso != 1 && ism != 1) {
243                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
244                        }
245                }
246        }
247
248        /**
249         * Create a stride array from a dataset to a broadcast shape
250         * @param a dataset
251         * @param broadcastShape
252         * @return stride array
253         */
254        public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
255                return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
256        }
257
258        /**
259         * Create a stride array from a dataset to a broadcast shape
260         * @param isize
261         * @param oShape original shape
262         * @param oStride original stride
263         * @param broadcastShape
264         * @return stride array
265         */
266        public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
267                int rank = oShape.length;
268                if (broadcastShape.length != rank) {
269                        throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
270                }
271        
272                int[] stride = new int[rank];
273                if (oStride == null) {
274                        int s = isize;
275                        for (int j = rank - 1; j >= 0; j--) {
276                                if (broadcastShape[j] == oShape[j]) {
277                                        stride[j] = s;
278                                        s *= oShape[j];
279                                } else {
280                                        stride[j] = 0;
281                                }
282                        }
283                } else {
284                        for (int j = 0; j < rank; j++) {
285                                if (broadcastShape[j] == oShape[j]) {
286                                        stride[j] = oStride[j];
287                                } else {
288                                        stride[j] = 0;
289                                }
290                        }
291                }
292        
293                return stride;
294        }
295
296        /**
297         * Converts and broadcast all objects as datasets of same shape
298         * @param objects
299         * @return all as broadcasted to same shape
300         */
301        public static Dataset[] convertAndBroadcast(Object... objects) {
302                final int n = objects.length;
303
304                Dataset[] datasets = new Dataset[n];
305                int[][] shapes = new int[n][];
306                for (int i = 0; i < n; i++) {
307                        Dataset d = DatasetFactory.createFromObject(objects[i]);
308                        datasets[i] = d;
309                        shapes[i] = d.getShapeRef();
310                }
311
312                List<int[]> nShapes = BroadcastUtils.broadcastShapes(shapes);
313                int[] mshape = nShapes.get(0);
314                for (int i = 0; i < n; i++) {
315                        datasets[i] = datasets[i].getBroadcastView(mshape);
316                }
317
318                return datasets;
319        }
320}