updates
This commit is contained in:
parent
88804cad3b
commit
8b9aae484b
13
detect.py
13
detect.py
|
@ -41,17 +41,6 @@ def detect(
|
|||
else: # darknet format
|
||||
load_weights(model, weights_file_path)
|
||||
|
||||
# current = model.state_dict()
|
||||
# saved = checkpoint['model']
|
||||
# # 1. filter out unnecessary keys
|
||||
# saved = {k: v for k, v in saved.items() if ((k in current) and (current[k].shape == v.shape))}
|
||||
# # 2. overwrite entries in the existing state dict
|
||||
# current.update(saved)
|
||||
# # 3. load the new state dict
|
||||
# model.load_state_dict(current)
|
||||
# model.to(device).eval()
|
||||
# del checkpoint, current, saved
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
# Set Dataloader
|
||||
|
@ -69,7 +58,7 @@ def detect(
|
|||
# cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed
|
||||
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
||||
if ONNX_EXPORT:
|
||||
pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True);
|
||||
pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True)
|
||||
return # ONNX export
|
||||
pred = model(img)
|
||||
pred = pred[pred[:, :, 4] > conf_thres]
|
||||
|
|
Loading…
Reference in New Issue