3333#include < cassert>
3434#include < cstdlib>
3535#include < cmath>
36+ #include < cstdint>
3637
3738namespace VecExpr {
3839
@@ -157,21 +158,28 @@ FORCEINLINE auto FUN(decay_t<S> u, const CVecExpr<V,S>& v) \
157158 RETURNS( EXPR<Bcast<S>,V,S>(Bcast<S>(u), v.derived()) \
158159) \
159160
160- /* --- std::max/min have issues (maybe because they return by reference).
161- * For AD codi::max/min need to be used to avoid issues in debug builds. ---*/
162-
163- #if defined(CODI_REVERSE_TYPE) || defined(CODI_FORWARD_TYPE)
164- #define max_impl math::max
165- #define min_impl math::min
166- #else
167- #define max_impl (a,b ) a<b? Scalar(b) : Scalar(a)
168- #define min_impl (a,b ) b<a? Scalar(b) : Scalar(a)
169- #endif
170- MAKE_BINARY_FUN (max, max_, max_impl)
171- MAKE_BINARY_FUN (min, min_, min_impl)
161+ /* --- std::max/min have issues (because they return by reference).
162+ * fmin and fmax return by value and thus are fine, but they would force
163+ * conversions to double, to avoid that we provide integer overloads.
164+ * We use int32/64 instead of int/long to avoid issues with Windows,
165+ * where long is 32 bits (instead of 64 bits). ---*/
166+
167+ #define MAKE_FMINMAX_OVERLOADS (TYPE ) \
168+ FORCEINLINE TYPE fmax (TYPE a, TYPE b) { return a<b? b : a; } \
169+ FORCEINLINE TYPE fmin (TYPE a, TYPE b) { return a<b? a : b; }
170+ MAKE_FMINMAX_OVERLOADS (int32_t )
171+ MAKE_FMINMAX_OVERLOADS (int64_t )
172+ MAKE_FMINMAX_OVERLOADS (uint32_t )
173+ MAKE_FMINMAX_OVERLOADS (uint64_t )
174+ /* --- Make the float and double versions of fmin/max available in this
175+ * namespace to avoid ambiguous overloads. ---*/
176+ using std::fmax;
177+ using std::fmin;
178+ #undef MAKE_FMINMAX_OVERLOADS
179+
180+ MAKE_BINARY_FUN (fmax, max_, fmax)
181+ MAKE_BINARY_FUN (fmin, min_, fmin)
172182MAKE_BINARY_FUN (pow, pow_, math::pow)
173- #undef max_impl
174- #undef min_impl
175183
176184/* --- sts::plus and co. were tried, the code was horrendous (due to the forced
177185 * conversion between different types) and creating functions for these ops
@@ -190,20 +198,25 @@ MAKE_BINARY_FUN(operator/, div_, div_impl)
190198#undef mul_impl
191199#undef div_impl
192200
193- /* --- Relational operators need to be cast to the scalar type to allow vectorization. ---*/
194-
195- #define le_impl (a,b ) Scalar(a<=b)
196- #define ge_impl (a,b ) Scalar(a>=b)
197- #define eq_impl (a,b ) Scalar(a==b)
198- #define ne_impl (a,b ) Scalar(a!=b)
199- #define lt_impl (a,b ) Scalar(a<b)
200- #define gt_impl (a,b ) Scalar(a>b)
201+ /* --- Relational operators need to be cast to the scalar type to allow vectorization.
202+ * TO_PASSIVE is used to convert active scalars to passive, which CoDi will then capture
203+ * by value in its expressions, and thus dangling references are avoided. No AD info
204+ * is lost since these operators are non-differentiable. ---*/
205+
206+ #define TO_PASSIVE (IMPL ) SU2_TYPE::Passive<Scalar>::Value(IMPL)
207+ #define le_impl (a,b ) TO_PASSIVE(a<=b)
208+ #define ge_impl (a,b ) TO_PASSIVE(a>=b)
209+ #define eq_impl (a,b ) TO_PASSIVE(a==b)
210+ #define ne_impl (a,b ) TO_PASSIVE(a!=b)
211+ #define lt_impl (a,b ) TO_PASSIVE(a<b)
212+ #define gt_impl (a,b ) TO_PASSIVE(a>b)
201213MAKE_BINARY_FUN (operator <=, le_, le_impl)
202214MAKE_BINARY_FUN (operator >=, ge_, ge_impl)
203215MAKE_BINARY_FUN (operator ==, eq_, eq_impl)
204216MAKE_BINARY_FUN (operator !=, ne_, ne_impl)
205217MAKE_BINARY_FUN (operator <, lt_, lt_impl)
206218MAKE_BINARY_FUN (operator >, gt_, gt_impl)
219+ #undef TO_PASSIVE
207220#undef le_impl
208221#undef ge_impl
209222#undef eq_impl
0 commit comments