Line data Source code
1 : //------------------------------------------------------------------------------
2 : // experimental/algorithm/LAGraph_argminmax
3 : //------------------------------------------------------------------------------
4 :
5 : // LAGraph, (c) 2019-2025 by The LAGraph Contributors, All Rights Reserved.
6 : // SPDX-License-Identifier: BSD-2-Clause
7 : // See additional acknowledgments in the LICENSE file,
8 : // or contact permission@sei.cmu.edu for the full terms.
9 :
10 : // Contributed by Olumayowa Olowomeye and Tim Davis, Texas A&M University
11 :
12 : //------------------------------------------------------------------------------
13 :
14 : // TODO: not ready for src, need to use IndexBinaryOps instead.
15 : // See argmin/argmax in GraphBLAS/@GrB/*/*/gbargminmax.c.
16 :
17 : #include "LG_internal.h"
18 : #include "LAGraphX.h"
19 : #include <LAGraph.h>
20 :
21 : // #define USAGE "usage: [x,p] = LAGraph_argminmax (A, minmax, dim)"
22 :
23 : //------------------------------------------------------------------------------
24 : // argminmax: compute argmin/max of each row/column of A
25 : //------------------------------------------------------------------------------
26 :
27 : #undef LG_FREE_WORK
28 : #define LG_FREE_WORK \
29 : { \
30 : GrB_free (&G) ; \
31 : GrB_free (&D) ; \
32 : GrB_free (&y) ; \
33 : }
34 :
35 : #undef LG_FREE_ALL
36 : #define LG_FREE_ALL \
37 : { \
38 : LG_FREE_WORK ; \
39 : GrB_free (x) ; \
40 : GrB_free (p) ; \
41 : }
42 :
43 : #if LAGRAPH_SUITESPARSE
44 :
45 352 : int argminmax
46 : (
47 : // output
48 : GrB_Matrix *x, // min/max value in each row/col of A
49 : GrB_Matrix *p, // index of min/max value in each row/col of A
50 : // input
51 : GrB_Matrix A,
52 : int dim, // dim=1: cols of A, dim=2: rows of A
53 : GrB_Semiring minmax_first, // MIN_FIRST_type or MAX_FIRST_type semiring
54 : GrB_Semiring any_equal, // ANY_EQ semiring
55 : char* msg
56 : )
57 : {
58 :
59 : //--------------------------------------------------------------------------
60 : // get the size and type of A
61 : //--------------------------------------------------------------------------
62 :
63 352 : GrB_Matrix y = NULL ;
64 352 : GrB_Matrix G = NULL, D = NULL ;
65 352 : (*x) = NULL ;
66 352 : (*p) = NULL ;
67 :
68 : GrB_Index nrows, ncols ;
69 352 : GRB_TRY (GrB_Matrix_nrows (&nrows, A)) ;
70 352 : GRB_TRY (GrB_Matrix_ncols (&ncols, A)) ;
71 : GrB_Type type ;
72 352 : GRB_TRY (GxB_Matrix_type (&type, A)) ;
73 :
74 : //--------------------------------------------------------------------------
75 : // create outputs x and p, and the iso full vector y
76 : //--------------------------------------------------------------------------
77 :
78 352 : GrB_Index n = (dim == 2) ? ncols : nrows ;
79 352 : GrB_Index m = (dim == 2) ? nrows : ncols ;
80 352 : GrB_Descriptor desc = (dim == 2) ? NULL : GrB_DESC_T0 ;
81 352 : GRB_TRY (GrB_Matrix_new (x, type, m, 1)) ;
82 352 : GRB_TRY (GrB_Matrix_new (&y, type, n, 1)) ;
83 352 : GRB_TRY (GrB_Matrix_new (p, GrB_INT64, m, 1)) ;
84 :
85 : // y (:) = 1, an iso full vector
86 352 : GRB_TRY (GrB_Matrix_assign_INT64 (y, NULL, NULL, 1, GrB_ALL, n, GrB_ALL, 1,
87 : NULL)) ;
88 :
89 : //--------------------------------------------------------------------------
90 : // compute x = min/max(A)
91 : //--------------------------------------------------------------------------
92 :
93 : // for dim=1: x = min/max (A) where x(j) = min/max (A (:,j))
94 : // for dim=2: x = min/max (A) where x(i) = min/max (A (i,:))
95 :
96 352 : GRB_TRY (GrB_mxm (*x, NULL, NULL, minmax_first, A, y, desc)) ;
97 :
98 : //--------------------------------------------------------------------------
99 : // D = diag (x)
100 : //--------------------------------------------------------------------------
101 :
102 : // note: typecasting from an m-by-1 GrB_Matrix to a GrB_Vector is
103 : // not allowed by the GraphBLAS C API, but it can be done in SuiteSparse.
104 : // A more portable method would construct x as a GrB_Vector,
105 : // but using x as a GrB_Matrix simplifies the gb_export.
106 :
107 352 : GRB_TRY (GrB_Matrix_diag (&D, (GrB_Vector) *x, 0)) ;
108 :
109 : //--------------------------------------------------------------------------
110 : // compute G, where G(i,j)=1 if A(i,j) is the min/max in its row/col
111 : //--------------------------------------------------------------------------
112 :
113 352 : GRB_TRY (GrB_Matrix_new (&G, GrB_BOOL, nrows, ncols)) ;
114 352 : if (dim == 1)
115 : {
116 : // G = A*D using the ANY_EQ_type semiring
117 264 : GRB_TRY (GrB_mxm (G, NULL, NULL, any_equal, A, D, NULL)) ;
118 : }
119 : else
120 : {
121 : // G = D*A using the ANY_EQ_type semiring
122 88 : GRB_TRY (GrB_mxm (G, NULL, NULL, any_equal, D, A, NULL)) ;
123 : }
124 :
125 : // drop explicit zeros from G
126 352 : GRB_TRY (GrB_Matrix_select_BOOL (G, NULL, NULL, GrB_VALUENE_BOOL, G, 0,
127 : NULL)) ;
128 :
129 : //--------------------------------------------------------------------------
130 : // extract the positions of the entries in G
131 : //--------------------------------------------------------------------------
132 :
133 : // for dim=1: find the position of the min/max entry in each column:
134 : // p = G'*y, so that p(j) = i if x(j) = A(i,j) = min/max (A (:,j)).
135 :
136 : // for dim=2: find the position of the min/max entry in each row:
137 : // p = G*y, so that p(i) = j if x(i) = A(i,j) = min/max (A (i,:)).
138 :
139 : // Use the SECONDI operator since built-in indexing is 0-based. The ANY
140 : // monoid would be faster, but this uses MIN monoid so that the result for
141 : // the user is repeatable.
142 352 : GRB_TRY (GrB_mxm (*p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, desc)) ;
143 :
144 : //--------------------------------------------------------------------------
145 : // free workspace
146 : //--------------------------------------------------------------------------
147 :
148 352 : GrB_Matrix_free (&D) ;
149 352 : GrB_Matrix_free (&G) ;
150 352 : GrB_Matrix_free (&y) ;
151 352 : return (GrB_SUCCESS) ;
152 : }
153 : #endif
154 :
155 : //------------------------------------------------------------------------------
156 : // gbargminmax: mexFunction to compute the argmin/max of each row/column of A
157 : //------------------------------------------------------------------------------
158 :
159 : #undef LG_FREE_WORK
160 : #define LG_FREE_WORK \
161 : { \
162 : GrB_free (&x1) ; \
163 : GrB_free (&p1) ; \
164 : GrB_free (&x) ; \
165 : GrB_free (&p) ; \
166 : }
167 :
168 : #undef LG_FREE_ALL
169 : #define LG_FREE_ALL \
170 : { \
171 : LG_FREE_WORK ; \
172 : GrB_free (x_result) ; \
173 : GrB_free (p_result) ; \
174 : }
175 :
176 264 : int LAGraph_argminmax
177 : (
178 : // output
179 : GrB_Vector *x_result, // min/max value in each row/col of A
180 : GrB_Vector *p_result, // index of min/max value in each row/col of A
181 : // input
182 : GrB_Matrix A,
183 : int dim, // dim=1: cols of A, dim=2: rows of A
184 : bool is_min,
185 : char *msg
186 : )
187 : {
188 : #if LAGRAPH_SUITESPARSE
189 :
190 : //--------------------------------------------------------------------------
191 : // check inputs
192 : //--------------------------------------------------------------------------
193 :
194 : // TODO: need LAGraph error checks here
195 :
196 264 : GrB_Matrix x = NULL, p = NULL, x1 = NULL, p1 = NULL ;
197 264 : (*x_result) = NULL ;
198 264 : (*p_result) = NULL ;
199 :
200 : //--------------------------------------------------------------------------
201 : // select the semirings
202 : //--------------------------------------------------------------------------
203 :
204 : GrB_Type type ;
205 264 : GRB_TRY (GxB_Matrix_type (&type, A)) ;
206 : GrB_Semiring minmax_first, any_equal ;
207 264 : if (is_min)
208 : {
209 :
210 : //----------------------------------------------------------------------
211 : // semirings for argmin
212 : //----------------------------------------------------------------------
213 :
214 : // TODO: use GrB or LAGraph_* semirings when possible
215 :
216 132 : if (type == GrB_BOOL)
217 : {
218 12 : minmax_first = GxB_LAND_FIRST_BOOL ;
219 12 : any_equal = GxB_ANY_EQ_BOOL ;
220 : }
221 120 : else if (type == GrB_INT8)
222 : {
223 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_INT8 ;
224 12 : any_equal = GxB_ANY_EQ_INT8 ;
225 : }
226 108 : else if (type == GrB_INT16)
227 : {
228 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_INT16 ;
229 12 : any_equal = GxB_ANY_EQ_INT16 ;
230 : }
231 96 : else if (type == GrB_INT32)
232 : {
233 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_INT32 ;
234 12 : any_equal = GxB_ANY_EQ_INT32 ;
235 : }
236 84 : else if (type == GrB_INT64)
237 : {
238 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_INT64 ;
239 12 : any_equal = GxB_ANY_EQ_INT64 ;
240 : }
241 72 : else if (type == GrB_UINT8)
242 : {
243 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_UINT8 ;
244 12 : any_equal = GxB_ANY_EQ_UINT8 ;
245 : }
246 60 : else if (type == GrB_UINT16)
247 : {
248 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_UINT16 ;
249 12 : any_equal = GxB_ANY_EQ_UINT16 ;
250 : }
251 48 : else if (type == GrB_UINT32)
252 : {
253 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_UINT32 ;
254 12 : any_equal = GxB_ANY_EQ_UINT32 ;
255 : }
256 36 : else if (type == GrB_UINT64)
257 : {
258 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_UINT64 ;
259 12 : any_equal = GxB_ANY_EQ_UINT64 ;
260 : }
261 24 : else if (type == GrB_FP32)
262 : {
263 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_FP32 ;
264 12 : any_equal = GxB_ANY_EQ_FP32 ;
265 : }
266 12 : else if (type == GrB_FP64)
267 : {
268 12 : minmax_first = GrB_MIN_FIRST_SEMIRING_FP64 ;
269 12 : any_equal = GxB_ANY_EQ_FP64 ;
270 : }
271 : else
272 : {
273 : // ERROR ("unsupported type") ; //ignoring for now
274 : }
275 :
276 : }
277 : else
278 : {
279 :
280 : //----------------------------------------------------------------------
281 : // semirings for argmax
282 : //----------------------------------------------------------------------
283 :
284 132 : if (type == GrB_BOOL)
285 : {
286 12 : minmax_first = GxB_LOR_FIRST_BOOL ;
287 12 : any_equal = GxB_ANY_EQ_BOOL ;
288 : }
289 120 : else if (type == GrB_INT8)
290 : {
291 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_INT8 ;
292 12 : any_equal = GxB_ANY_EQ_INT8 ;
293 : }
294 108 : else if (type == GrB_INT16)
295 : {
296 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_INT16 ;
297 12 : any_equal = GxB_ANY_EQ_INT16 ;
298 : }
299 96 : else if (type == GrB_INT32)
300 : {
301 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_INT32 ;
302 12 : any_equal = GxB_ANY_EQ_INT32 ;
303 : }
304 84 : else if (type == GrB_INT64)
305 : {
306 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_INT64 ;
307 12 : any_equal = GxB_ANY_EQ_INT64 ;
308 : }
309 72 : else if (type == GrB_UINT8)
310 : {
311 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_UINT8 ;
312 12 : any_equal = GxB_ANY_EQ_UINT8 ;
313 : }
314 60 : else if (type == GrB_UINT16)
315 : {
316 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_UINT16 ;
317 12 : any_equal = GxB_ANY_EQ_UINT16 ;
318 : }
319 48 : else if (type == GrB_UINT32)
320 : {
321 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_UINT32 ;
322 12 : any_equal = GxB_ANY_EQ_UINT32 ;
323 : }
324 36 : else if (type == GrB_UINT64)
325 : {
326 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_UINT64 ;
327 12 : any_equal = GxB_ANY_EQ_UINT64 ;
328 : }
329 24 : else if (type == GrB_FP32)
330 : {
331 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_FP32 ;
332 12 : any_equal = GxB_ANY_EQ_FP32 ;
333 : }
334 12 : else if (type == GrB_FP64)
335 : {
336 12 : minmax_first = GrB_MAX_FIRST_SEMIRING_FP64 ;
337 12 : any_equal = GxB_ANY_EQ_FP64 ;
338 : }
339 : else
340 : {
341 : // ERROR ("unsupported type") ;
342 : }
343 : }
344 :
345 : //--------------------------------------------------------------------------
346 : // compute the argmin/max
347 : //--------------------------------------------------------------------------
348 :
349 264 : if (dim == 0)
350 : {
351 :
352 : //----------------------------------------------------------------------
353 : // scalar argmin/max of all of A
354 : //----------------------------------------------------------------------
355 :
356 : // [x1,p1] = argmin/max of each column of A
357 88 : argminmax (&x1, &p1, A, 1, minmax_first, any_equal, msg) ;
358 : // [x,p] = argmin/max of each entry in x
359 88 : argminmax (&x, &p, x1, 1, minmax_first, any_equal, msg) ;
360 : // get the row and column index of the overall argmin/max of A
361 88 : int64_t I [2] = { 0, 0 } ;
362 : GrB_Index nvals0, nvals1 ;
363 88 : GRB_TRY (GrB_Matrix_nvals (&nvals0, p)) ;
364 88 : GRB_TRY (GrB_Matrix_nvals (&nvals1, p1)) ;
365 88 : if (nvals0 > 0 && nvals1 > 0)
366 : {
367 : // I [0] = p [0], the row index of the global argmin/max of A
368 88 : GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [0]), p, 0, 0)) ;
369 : // I [1] = p1 [I [0]]
370 : // which is the column index of the global argmin/max of A
371 88 : GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [1]), p1, I [0], 0)) ;
372 : }
373 :
374 : // free workspace and create p = [row, col]
375 88 : GRB_TRY (GrB_Matrix_free (&x1)) ;
376 88 : GRB_TRY (GrB_Matrix_free (&p1)) ;
377 88 : GRB_TRY (GrB_Matrix_free (&p)) ;
378 88 : GRB_TRY (GrB_Vector_new (x_result, type, 1)) ;
379 88 : GRB_TRY (GrB_Vector_new (p_result, GrB_INT64, 2)) ;
380 88 : if (nvals0 > 0 && nvals1 > 0)
381 : {
382 : // x_result = x (:,0)
383 88 : GRB_TRY (GrB_Col_extract (*x_result, NULL, NULL, x, GrB_ALL,
384 : 1, 0, NULL)) ;
385 : // p_result = [row, col]
386 88 : GRB_TRY (GrB_Vector_setElement_INT64 (*p_result, I [1], 0)) ;
387 88 : GRB_TRY (GrB_Vector_setElement_INT64 (*p_result, I [0], 1)) ;
388 : }
389 :
390 : }
391 176 : else if (dim == 1)
392 : {
393 :
394 : //----------------------------------------------------------------------
395 : // argmin/max of each column of A
396 : //----------------------------------------------------------------------
397 :
398 88 : argminmax (&x, &p, A, 1, minmax_first, any_equal, msg) ;
399 : }
400 : else
401 : {
402 :
403 :
404 : //----------------------------------------------------------------------
405 : // argmin/max of each row of A
406 : //----------------------------------------------------------------------
407 :
408 88 : argminmax (&x, &p, A, 2, minmax_first, any_equal, msg) ;
409 : }
410 :
411 : //--------------------------------------------------------------------------
412 : // return result
413 : //--------------------------------------------------------------------------
414 :
415 264 : if (dim != 0)
416 : {
417 : // x_result = x (:,0)
418 : // p_result = p (:,0)
419 : GrB_Index m ;
420 176 : GRB_TRY (GrB_Matrix_nrows (&m, x)) ;
421 176 : GRB_TRY (GrB_Vector_new (x_result, type, m)) ;
422 176 : GRB_TRY (GrB_Vector_new (p_result, GrB_INT64, m)) ;
423 176 : GRB_TRY (GrB_Col_extract (*x_result, NULL, NULL, x, GrB_ALL, m, 0,
424 : NULL)) ;
425 176 : GRB_TRY (GrB_Col_extract (*p_result, NULL, NULL, p, GrB_ALL, m, 0,
426 : NULL)) ;
427 : }
428 :
429 264 : LG_FREE_WORK ;
430 264 : return (GrB_SUCCESS) ;
431 : #else
432 : return (GrB_NOT_IMPLEMENTED);
433 : #endif
434 : }
435 :
|