Building custom loss functions
A loss function is a function that takes as input a Lux layer, its parameters and state, and a named tuple containing the points where the model is evaluated, the preprocessed input (`inputs`) and the algebraic distance to the surface (`d_reals`). The loss function then returns a tuple of 3 values:
- the loss: a scalar number
- the updated state of the model
- a named tuple containing evaluation metrics
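The following is a minimal sketch of this contract. It is not the loss defined below. It uses a hypothetical mean-squared-error loss, a plain Lux `Dense` layer and random data. The batch carries an extra field, `points`, whose name is an assumption. It shows that the `(; inputs, d_reals)` destructuring ignores fields the loss does not use.

```julia
using Lux, Random, Statistics

# Hypothetical loss following the same contract: regress the model output onto d_reals.
function mse_loss(model, ps, st, (; inputs, d_reals))
    d_pred, st = Lux.apply(model, inputs, ps, st)  # forward pass, also returns the updated state
    loss = mean(abs2, d_pred .- d_reals)           # scalar loss
    return loss, st, (; mse = loss)                # (loss, state, metrics)
end

rng = Random.Xoshiro(0)
model = Dense(3 => 1)                  # one output per evaluation point
ps, st = Lux.setup(rng, model)

# `points` is a hypothetical field name; the destructuring above simply ignores it
batch = (; points = rand(rng, Float32, 3, 8),
           inputs = rand(rng, Float32, 3, 8),
           d_reals = rand(rng, Float32, 1, 8))

loss, st, metrics = mse_loss(model, ps, st, batch)
println(loss, " ", metrics)
```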
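using Lux, Statistics                     # Lux.apply and cpu_device; mean
using ChainRulesCore: ignore_derivatives  # keeps target generation out of the gradient
# KL and generate_true_probabilities are provided by the package itself

function custom_loss(model, ps, st, (; inputs, d_reals))::Tuple{Float32, Any, @NamedTuple{}}
    # model evaluation: one output p per evaluation point, read as a probability
    v_pred, st = Lux.apply(model, inputs, ps, st)
    # two-class distribution: each column becomes [p; 1 - p]
    v_pred = vcat(v_pred, 1 .- v_pred)
    # move the prediction to the CPU before comparing it with the targets
    v_pred = cpu_device()(v_pred)
    # target probabilities from the algebraic distances, excluded from differentiation
    probabilities = ignore_derivatives() do
        generate_true_probabilities(d_reals)
    end
    # (loss, updated state, metrics); the metrics tuple is empty, matching _metric_type
    return mean(KL(probabilities, v_pred)), st, (;)
end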
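In the comparison step, the single model output p is expanded into the two-class distribution [p; 1 - p], and the package's `KL` compares that distribution with the target. Neither `KL` nor `generate_true_probabilities` is shown in this section. The sketch below therefore uses a hypothetical stand-in, `kl_sketch`, and assumes that `KL(p, q)` returns the column-wise Kullback-Leibler divergence of `q` from `p`. The target probabilities are written by hand instead of being generated from distances.

```julia
using Statistics

# Hypothetical stand-in for the package's KL: column-wise KL(p || q) = sum_i p_i * log(p_i / q_i),
# where each column of p (targets) and q (predictions) is a probability distribution.
# `eps` keeps the logarithm finite when a probability is exactly zero.
function kl_sketch(p::AbstractMatrix, q::AbstractMatrix; eps = 1f-7)
    return vec(sum(p .* log.((p .+ eps) ./ (q .+ eps)); dims = 1))
end

v = Float32[0.9 0.2 0.5]               # 1x3 model output: probability p for 3 points
v_pred = vcat(v, 1 .- v)               # 2x3: each column is [p, 1 - p]
target = Float32[1 0 0.5; 0 1 0.5]     # hypothetical target probabilities (2x3)

losses = kl_sketch(target, v_pred)     # one divergence per point
println(losses, " ", mean(losses))
```

The third point's prediction equals its target, so its divergence is zero. Averaging the per-point values with `mean` gives the scalar that the loss function returns.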
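Once we have the loss function, we need to register it in order to use it from the command line. If this code lives outside the package, the functions extended below (_metric_type, get_loss_fn and _get_loss_type) have to be imported or qualified with the package's module name. Otherwise the new methods are attached to new functions of the same name rather than to the package's functions.
First, we need a type to represent the loss function.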
struct CustomLoss <: LossType end
Next, we need to declare the type of the metrics returned by the loss function. In our case it is an empty NamedTuple.
_metric_type(::Type{CustomLoss}) = @NamedTuple{}
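The declared metric type has to match the metrics NamedTuple that the loss function actually returns, including both field names and field types. Here is a small check in which the `accuracy` metric is purely hypothetical:

```julia
# custom_loss returns (;), which is exactly the declared @NamedTuple{}
empty_ok = (;) isa @NamedTuple{}

# Hypothetical: a loss returning (; accuracy = acc) with a Float32 acc would
# declare @NamedTuple{accuracy::Float32} as its metric type instead.
acc = 0.75f0
acc_ok = (; accuracy = acc) isa @NamedTuple{accuracy::Float32}

# Field types must match too: a Float64 accuracy does not fit that declaration.
mismatch = (; accuracy = 0.75) isa @NamedTuple{accuracy::Float32}

println(empty_ok, " ", acc_ok, " ", mismatch)
```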
Then we associate the loss function with our new type.
get_loss_fn(::CustomLoss) = custom_loss
Finally, we define the name used on the command line to select our loss, here custom.
_get_loss_type(::StaticSymbol{:custom}) = CustomLoss()
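The registration relies only on multiple dispatch. A name selects a singleton type, and that type selects both the loss function and the metric type. The self-contained mock below reproduces this flow under several assumptions:
- LossType and the three functions are redefined locally instead of extended.
- Base's `Val` stands in for Static.jl's `StaticSymbol`, so the snippet runs without extra dependencies.
- `custom_loss` has a placeholder body.
- `parse_loss` is a hypothetical helper showing what the command-line front end might do with the name. The package's real parsing code is not shown here.

```julia
# Local mock of the registration pattern (the real package defines these itself).
abstract type LossType end

# Hypothetical fallback for names that no loss has registered.
_get_loss_type(::Val{S}) where {S} = throw(ArgumentError("unknown loss: $S"))

struct CustomLoss <: LossType end
_metric_type(::Type{CustomLoss}) = @NamedTuple{}
custom_loss(model, ps, st, data) = (0.0f0, st, (;))  # placeholder body
get_loss_fn(::CustomLoss) = custom_loss
_get_loss_type(::Val{:custom}) = CustomLoss()         # Val stands in for StaticSymbol

# Hypothetical front end: turn the command-line string into a dispatchable value.
parse_loss(name::AbstractString) = _get_loss_type(Val(Symbol(name)))

loss_type = parse_loss("custom")
loss_fn = get_loss_fn(loss_type)
loss, st, metrics = loss_fn(nothing, nothing, (;), (;))
println(typeof(loss_type), " ", metrics isa _metric_type(typeof(loss_type)))
```

Because the name is dispatched on as a type-level symbol (`StaticSymbol` in the package, `Val` in this mock), registering another loss only requires defining one more method, with no central lookup table to edit.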