updates
This commit is contained in:
parent
2195bb0e89
commit
e1c407dab1
1
train.py
1
train.py
|
@ -178,6 +178,7 @@ def train(cfg,
|
|||
world_size=1, # number of nodes for distributed training
|
||||
rank=0) # distributed training node rank
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
model.yolo_layers = model.module.yolo_layers # move yolo layer indices to top level
|
||||
|
||||
# Dataset
|
||||
dataset = LoadImagesAndLabels(train_path,
|
||||
|
|
|
@ -334,13 +334,12 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
|||
def build_targets(model, targets):
|
||||
# targets = [image, class, x, y, w, h]
|
||||
iou_thres = model.hyp['iou_t'] # hyperparameter
|
||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
|
||||
nt = len(targets)
|
||||
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
|
||||
for i in model.yolo_layers:
|
||||
layer = model.module_list[i]
|
||||
layer = model.module.module_list[i] if multi_gpu else model.module_list[i]
|
||||
|
||||
# iou of targets-anchors
|
||||
t, a = targets, []
|
||||
|
|
Loading…
Reference in New Issue