diff --git a/utils/utils.py b/utils/utils.py index 59fe1a8c..2fb8c75d 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -259,7 +259,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # Define criteria MSE = nn.MSELoss() - CE = nn.CrossEntropyLoss(weight=model.class_weights) + CE = nn.CrossEntropyLoss() # (weight=model.class_weights) BCE = nn.BCEWithLogitsLoss() # Compute losses