This commit is contained in:
Glenn Jocher 2019-08-05 17:25:50 +02:00
parent 2195bb0e89
commit e1c407dab1
2 changed files with 3 additions and 3 deletions

View File

@ -178,6 +178,7 @@ def train(cfg,
world_size=1, # number of nodes for distributed training world_size=1, # number of nodes for distributed training
rank=0) # distributed training node rank rank=0) # distributed training node rank
model = torch.nn.parallel.DistributedDataParallel(model) model = torch.nn.parallel.DistributedDataParallel(model)
model.yolo_layers = model.module.yolo_layers # move yolo layer indices to top level
# Dataset # Dataset
dataset = LoadImagesAndLabels(train_path, dataset = LoadImagesAndLabels(train_path,

View File

@ -334,13 +334,12 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
def build_targets(model, targets): def build_targets(model, targets):
# targets = [image, class, x, y, w, h] # targets = [image, class, x, y, w, h]
iou_thres = model.hyp['iou_t'] # hyperparameter iou_thres = model.hyp['iou_t'] # hyperparameter
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
model = model.module
nt = len(targets) nt = len(targets)
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], [] txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
for i in model.yolo_layers: for i in model.yolo_layers:
layer = model.module_list[i] layer = model.module.module_list[i] if multi_gpu else model.module_list[i]
# iou of targets-anchors # iou of targets-anchors
t, a = targets, [] t, a = targets, []