# loss functions
import jax.numpy as np
[docs]def mse(y_pred, y_true):
''' Description: mean-square-error loss
Args:
y_pred : value predicted by model
y_true : ground truth value
eps: some scalar
'''
return np.mean((y_pred - y_true)**2)
[docs]def cross_entropy(y_pred, y_true, eps=1e-9):
''' Description: cross entropy loss, y_pred is equivalent to logits and y_true to labels
Args:
y_pred : value predicted by model
y_true : ground truth value
eps: some scalar
'''
return - np.mean(y_true * np.log(y_pred + eps))