Data type of loss functions #2338

@mrityunjay-tripathi

The loss functions in `ann` only return a loss of type `double`. Could we template the return type as well, so that users can choose whatever precision they want? In general we don't need double precision (up to 17 significant decimal digits) when calculating a loss, so I think `float` should work fine.

For example

```cpp
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename TargetType>
typename InputType::elem_type MyLossFunction<InputDataType, OutputDataType>::Forward(
    const InputType& input,
    const TargetType& target)
{
  // Armadillo matrices expose their element type as elem_type.
  using RealType = typename InputType::elem_type;

  // Use RealType wherever the loss value has to be stored or computed.
  RealType loss = 0;
  // ... accumulate the loss over input and target ...
  return loss;
}
```

What do you think?