#include <cblib/Base.h>
#include <cblib/Log.h>

USE_CB

//=============================================================

// Nan test without calling the C library : isnan() is correct but slow in unoptimized (debug) builds.
//	Nan is the only value that compares unequal to itself.
//	Note : the argument is evaluated twice, and a compiler told to assume no nans (fast-math style flags)
//	may fold this to false.
#define ISNAN(f)	( (f) != (f) )

// Reinterpret the bits of a float as a uint32 (a bit copy, not a numeric conversion).
//	Optimized builds use memcpy, the well-defined way to type-pun.
//	Debug builds read through a pointer cast instead, presumably because the memcpy call
//	is not inlined there and is slow; the cast is technically a strict-aliasing violation.
static inline uint32 fbits_as_u32(float f)
{
	#ifndef _DEBUG
	uint32 bits;
	memcpy(&bits,&f,sizeof(float));
	return bits;
	#else
	const uint32 * p_bits = (const uint32 *)&f;
	return *p_bits;
	#endif
}

// Reinterpret the bits of a uint32 as a float (a bit copy, not a numeric conversion).
//	Optimized builds use memcpy, the well-defined way to type-pun.
//	Debug builds read through a pointer cast instead, presumably because the memcpy call
//	is not inlined there and is slow; the cast is technically a strict-aliasing violation.
static inline float fbits_from_u32(uint32 u)
{
	#ifndef _DEBUG
	float value;
	memcpy(&value,&u,sizeof(float));
	return value;
	#else
	const float * p_value = (const float *)&u;
	return *p_value;
	#endif
}

// True when f and g have identical bit patterns.
//	Unlike (f == g), this tells 0.f apart from -0.f, and a nan equals a nan with the same bits.
static inline bool fbits_equal(float f,float g)
{
	const int byte_diff = memcmp(&f,&g,sizeof(float));
	return ! byte_diff;
}

// Float-value equality (f == g), extended so that a nan equals any other nan.
//	A plain == is not enough : every comparison involving a nan is false.
//	As with ==, 0.f and -0.f count as equal.
static inline bool float_equal_preserve_nan(float f,float g)
{
	const bool both_nan = ISNAN(f) && ISNAN(g);
	return ( f == g ) || both_nan;
}

//=============================================================

// Identity mapping : the integer is just the float's raw bits.
//	Lossless and exactly invertible; negative floats keep their sign bit set.
struct floatmap_just_cast
{
	static uint32 forward(float f)	{ return fbits_as_u32(f); }
	static float reverse(uint32 u)	{ return fbits_from_u32(u); }
};

struct floatmap_lossless_fix_negatives
{
	//	just xor the propagated sign bit
	//	very obviously lossless and self-inverting
	//	this takes 0.f -> 0 and -0.f -> -1 , so they are preserved but adjacent in integers

	static uint32 forward(float f)
	{
		int32 ret = fbits_as_u32(f);

		int32 propagated_sign = ret>>31;

		ret ^= uint32(propagated_sign)>>1;  // not top bit
		return ret;
	}

	static float reverse(uint32 u)
	{
		int32 propagated_sign = int32(u)>>31;
		
		u ^= uint32(propagated_sign)>>1;  // not top bit

		return fbits_from_u32(u);
	}
};

// generic adapter to handle negatives and nans, and call to a wrapper that only sees positives
//	t_positives_mapper will only receive forward() calls for f > 0.f
template <typename t_positives_mapper>
struct floatmap_adapter_to_positives_only
{
	// values in (7F800000,80800000) (non-inclusive) do not occur in the normal mapped range, so can be used for nan
	//	no single byte repeat value for u32_nan is possible (7F7F7F7F and 80808080 are outside the range)
	//	0x80000000U is right in the middle of the nan range
	static constexpr uint32 u32_nan = 0x80000000U;
	//static constexpr uint32 u32_nan = 0x80008000U; // also a possible option

	static uint32 forward(float f)
	{
		if ( f > 0.f )
		{
			return t_positives_mapper::forward(f);
		}
		else if ( f < 0.f )
		{
			return - (int32) t_positives_mapper::forward(-f);
		}
		else if ( f == 0.f )
		{
			return 0; // also -0.f
		}
		else
		{
			// nan fails all compares so goes here
			
			return u32_nan; // all nans changed to same value
		}
	}
	
	static float reverse(uint32 u)
	{
		if ( u == 0 )
		{
			return 0.f;
		}
		else if ( u == u32_nan )
		{
			return NAN;
		}
		else if ( int32(u) < 0 )
		{
			return - t_positives_mapper::reverse(- (int32)u);
		}
		else // positive
		{
			return t_positives_mapper::reverse(u);
		}
	}
};

// "Lossless" as float values : every float comes back equal as a float
//	(float_equal_preserve_nan holds), but not always with the same bits :
//	-0.f comes back as 0.f, and every nan comes back as one canonical nan.
//	Negatives are handled by floatmap_adapter_to_positives_only; positives are just their bits.
struct floatmap_fix_negatives_lossyzeronan : public floatmap_adapter_to_positives_only<floatmap_just_cast>
{
};

// use with floatmap_adapter_to_positives_only
struct floatmappositives_lossy_add1
{
	// Code f as the bits of (f + 1) : values in [0,1] use the linear 23-bit mantissa of [1,2],
	//	so no negative exponents are produced.
	// The +1 bias rounds away low bits of small values, an extra round-off error,
	//	so this method is generally worse than floatmappositives_lossy_denorm1.
	// The bits are shifted down by (u_one - 1) so codes start at 1 :
	//	this squeezes out the unused range below the bits of 1.0 (the deadzone at zero)
	//	and leaves code 0 to the adapter's zero.

	static constexpr uint32 u_one = 0x3F800000U; // bits of 1.0f

	static uint32 forward(float f)
	{
		ASSERT( f > 0.f );
		const float biased = f + 1.f;
		const uint32 bits = fbits_as_u32(biased);

		ASSERT( bits >= u_one );
		return (bits - u_one) + 1;
	}

	static float reverse(uint32 u)
	{
		const float biased = fbits_from_u32( (u - 1) + u_one );
		ASSERT( biased >= 1.f );
		return biased - 1.f;
	}
};

// use with floatmap_adapter_to_positives_only
struct floatmappositives_lossy_denorm1
{
	// Scale by 2^-126 so that inputs below 1.0 land in the float denormal range :
	//	[0,1) becomes a linear 23-bit range (no negative exponents survive),
	//	while inputs >= 1.0 keep all their bits (the exponent just drops by 126).
	// The maximum error, for inputs below 1.0, is 2^-24.
	// Rounding is done by the multiply itself, under the current FPU rounding mode.
	// This is fast on CPUs where denormals are fast.
	// NOTE(review) : presumably requires flush-to-zero / denormals-are-zero to be off;
	//	with them on, inputs below 1.0 would come back as 0 - confirm for the target.

	static uint32 forward(float f)
	{
		ASSERT( f > 0.f );
		const float scaled_down = f * 0x1.p-126f;
		return fbits_as_u32(scaled_down);
	}

	static float reverse(uint32 u)
	{
		const float scaled_down = fbits_from_u32(u);
		return scaled_down * 0x1.p126f;
	}
};

// Round a non-negative float to the nearest int32, with exact halves rounding up.
//	The obvious (int32)(f + 0.5f) is wrong because the add itself rounds in float :
//	0.49999997f + 0.5f rounds to 1.f (giving 1 instead of 0), and for f >= 2^23 the sum of
//	an odd integer and 0.5f ties to the next even integer (8388609.f would give 8388610).
//	Instead, the fractional part is computed exactly and compared with 0.5f.
// f must be in [0, 2^31) : converting a larger float to int32 is undefined.
static inline int32 ftoi_round_positive( float f )
{
	ASSERT( f >= 0.f ); // for negatives would have to detect and do (f - 0.5f)

	int32 i = (int32) f; // truncate toward zero

	// trunc(f) is itself representable as a float, and f minus it (the fractional part) is exact
	float frac = f - (float) i;
	if ( frac >= 0.5f )
		i++; // cannot overflow : for f >= 2^23 there is no fractional part

	return i;
}
	
// Round a float to int32 using the SSE rounding mode in MXCSR
//	(round-to-nearest-even by default, hence "banker").
static inline int32 ftoi_round_banker( float f )
{
	// cvtss2si : an inexact conversion is rounded according to the MXCSR rounding-control bits
	//	(or embedded rounding control). That is the same rule ordinary float math on the vector
	//	registers uses, so it matches the rounding floatmappositives_lossy_denorm1 gets from its multiply.
	const __m128 scalar = _mm_set_ss( f );
	return _mm_cvt_ss2si( scalar );
}

// use with floatmap_adapter_to_positives_only
struct floatmappositives_lossy_logint
{
	// same as lossy_denorm1
	//	alternate implementation that doesn't use FPU denormals, manually makes them

	// output is exactly the same as lossy_denorm1 (except on inf)

	static uint32 forward(float f)
	{
		ASSERT( f > 0.f );
		if ( f >= 1.f )
		{
			uint32 u = fbits_as_u32(f);
			return u - 0x3F000000U;
		}
		else
		{
			// if we use banker rounding, the results are identical to floatmappositives_lossy_denorm1
			//	without, the results are not worse but can be different in the bottom bit
			//	in cases where it doesn't change error if the bottom bit rounds up or down
			//uint32 u = ftoi_round_positive( f * 0x1.p23f );
			uint32 u = ftoi_round_banker( f * 0x1.p23f );
			return u;
		}
	}

	static float reverse(uint32 u)
	{
		if ( u >= 0x800000U )
		{
			float f = fbits_from_u32(u + 0x3F000000U);
			return f;
		}
		else
		{
			float f = float(u) * 0x1.p-23f; 
			return f;
		}
	}
};

// Complete float <-> uint32 maps : each lossy positives-only mapper
//	wrapped by floatmap_adapter_to_positives_only, which deals with zeros, negatives and nans.

struct floatmap_lossy_add1 : public floatmap_adapter_to_positives_only<floatmappositives_lossy_add1>
{
};

struct floatmap_lossy_denorm1 : public floatmap_adapter_to_positives_only<floatmappositives_lossy_denorm1>
{
};

struct floatmap_lossy_logint : public floatmap_adapter_to_positives_only<floatmappositives_lossy_logint>
{
};

//=============================================================
// In cblib proper these helpers would return a String built with autoprintf;
//	here we just strdup the text and intentionally leak it (the strings only feed log output).

// Heap copy of the binary text that cblib's sprintf_binary writes for (u,bits).
//	NOTE(review) : assumes that text fits in 100 chars - confirm against sprintf_binary.
//	The returned string is malloc'd by strdup and never freed by callers here.
static char * strdup_binary(uint64 u,int bits)
{
	char buffer[100];
	sprintf_binary(buffer,u,bits);
	char * copy = strdup(buffer);
	return copy;
}

// Heap copy of the text that cblib's sprintf_binary_float writes for f.
//	NOTE(review) : assumes that text fits in 100 chars - confirm against sprintf_binary_float.
//	The returned string is malloc'd by strdup and never freed by callers here.
static char * strdup_binary_float(float f)
{
	char buffer[100];
	sprintf_binary_float(buffer,f);
	char * copy = strdup(buffer);
	return copy;
}

//=============================================================
// test_conversion
// 
// run the forward/reverse on all floats

template <typename t_convert>
void test_conversion(const char * name)
{
	bool equal_exact = true;
	bool equal_float = true;
	double max_err = 0.0;

	lprintf("%-30s: ",name);

	for(uint64 i=0;i<=UINT32_MAX;i++)
	{
		// visit all U32 values but take a prime step so we do a variety of stuff earlier :
		uint32 x = (uint32)(i * 1000000007);

		// start with float f
		float f = fbits_from_u32((uint32)x);

		// run f through forward & reverse conversion
		uint32 u = t_convert::forward(f);
		float g = t_convert::reverse(u);

		// see how f and g compare :

		if ( fbits_equal(f,g) ) continue;

		equal_exact = false;

		if ( float_equal_preserve_nan(f,g) ) continue;

		equal_float = false;
		
		/*
		// for debugging, log :
		lprintf("equal_float = false\n");
		lprintf("u = %08X\n",u);
		lprintf("f = %g = %s\n",f,strdup_binary_float(f));
		lprintf("g = %g = %s\n",g,strdup_binary_float(g));
		*/

		// none of the conversions should change nan-ness :
		DURING_ASSERT( bool fnan = ISNAN(f); bool gnan = ISNAN(g); );
		ASSERT( fnan == gnan );
		// therefore float_equal_preserve_nan returned true so we should not have gotten here
		ASSERT( ! fnan && ! gnan );

		double delta = (double)f - g;
		delta = fabs(delta);

		// make relative delta, but only for values >= 1
		delta /= MAX(fabs(f), 1.0);

		max_err = MAX(max_err,delta);
	}

	if ( equal_exact )
		lprintf("exact bits\n");
	else if ( equal_float )
		lprintf("equal floats\n");
	else
		lprintf("max error : %g = %s\n",
			max_err,strdup_binary_float((float)max_err));
}

//=============================================================

int main(int argc,const char *argv[])
{
	// The "#if 0" sections below are one-off experiments kept for reference;
	//	change one to "#if 1" to run it. Only the test_conversion calls at the bottom are active.

	#if 0
	// check strdup_binary_float on a few interesting values :
	const float test_floats[] = { 1.f , 100.f, -INFINITY,  -0.f, NAN,
		1.4012984643e-45f,     // 2^-149, the smallest positive denormal
		1.00000011920928955f   // 1 + 2^-23, the float just above 1
	};

	for (float f : test_floats)
	{
		lprintf("%g = %s\n",f,strdup_binary_float(f));
	}
	#endif

	#if 0
	// floatmap_fix_negatives_lossyzeronan valid mapping range is between these (the codes of +inf and -inf);
	//	the midpoint of the unused gap between them is where u32_nan is placed :
	lprintf("%08X\n",fbits_as_u32(INFINITY)); // 7F800000
	lprintf("%08X\n",-(int32)fbits_as_u32(INFINITY)); // 80800000
	uint32 x = (0x7F800000ULL + 0x80800000ULL)/2;
	lprintf("%08X\n",x);
	#endif

	#if 0
	// floatmap_lossy_logint and floatmap_lossy_denorm1 differ only on inf,
	//	so this ASSERT_RELEASE is expected to fire :
	{
		float f = INFINITY;
		
		uint32 u1 = floatmap_lossy_logint::forward(f);
		uint32 u2 = floatmap_lossy_denorm1::forward(f);
		
		ASSERT_RELEASE( u1 == u2 );
	}
	#endif

	#if 0
	// compare floatmap_lossy_logint and floatmap_lossy_denorm1 on every float :
	for(uint64 i=0;i<=UINT32_MAX;i++)
	{
		// visit all U32 values but take a prime step so we do interesting stuff earlier :
		uint32 x = (uint32)(i * 1000000007);
		
		// start with float f
		float f = fbits_from_u32((uint32)x);

		uint32 u1 = floatmap_lossy_logint::forward(f);
		uint32 u2 = floatmap_lossy_denorm1::forward(f);

		// different for inf
		//ASSERT_RELEASE( u1 == u2 );

		if ( u1 != u2 )
		{
			float g1 = floatmap_lossy_logint::reverse(u1);
			float g2 = floatmap_lossy_denorm1::reverse(u2);
			
			// with rounding the same, the only difference is how inf is mapped :
			ASSERT_RELEASE( isinf(f) );
			ASSERT_RELEASE( f == g1 && f == g2 );

			#if 0
			if ( isinf(f) )
			{
				// g1 and g2 should both be inf, and should match sign of f

				ASSERT_RELEASE( isinf(g1) );
				ASSERT_RELEASE( isinf(g2) );

				if ( f > 0.f )
					ASSERT_RELEASE( g1 > 0.f && g2 > 0.f ); 
				else
					ASSERT_RELEASE( g1 < 0.f && g2 < 0.f ); 
			}
			else
			{
				ASSERT_RELEASE( fabs(double(f)-g1) == fabs(double(f)-g2) );

				ASSERT_RELEASE( ABS(int32(u1-u2)) <= 1 );

				// floatmap_lossy_logint and floatmap_lossy_denorm1
				//	almost exactly produce the same output, but not quite
				//	due to different rounding rules when they are at a midpoint
				//	(due to round-nearest vs. a banker round)
				// when that happens the uint32 output differs by 1
				//	and the error from the original is identical
				// eg. it rounds differently when the rounding direction does not change error
			}
			#endif
		}
	}
	#endif

	// round-trip every float bit pattern through each mapping and log how well it survives :
	test_conversion<floatmap_just_cast>("just_cast");
	test_conversion<floatmap_lossless_fix_negatives>("lossless_fix_negatives");
	test_conversion<floatmap_fix_negatives_lossyzeronan>("fix_negatives_lossyzeronan");
	test_conversion<floatmap_lossy_logint>("lossy_logint");
	test_conversion<floatmap_lossy_denorm1>("lossy_denorm1");
	test_conversion<floatmap_lossy_add1>("lossy_add1");

	return 0;
}
