[YOLOv5 detection code simplified] Streamlining YOLOv5's detect.py inference: image in, annotated image and results out
Preface
A recent project embeds YOLOv5. The requirement is inference only: the model file already exists, the input needs to be an image (the original YOLOv5 takes a file path), and the outputs are the annotated result image and the labels. That calls for some simplification and changes to the original code.
Paths
The directory structure for the model part looks like this; since other models will be added to the project later, it is a bit nested (a sketch follows below).
__init__.py files are added so the folders can be imported as packages.
models and utils are the original yolov5 files; which of them you actually need is up to you to sort out.
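A rough sketch of the layout, reconstructed from the import paths and the weights path in the code below; the screw.pt weights are mine, and the exact set of files you keep from yolov5 will differ:

```
project/
└── model/
    ├── __init__.py
    └── yolov5/
        ├── __init__.py
        ├── screw.pt        # the trained weights
        ├── models/         # original yolov5 model code (experimental.py etc.)
        │   └── __init__.py
        └── utils/          # original yolov5 helpers (general.py, plots.py etc.)
            └── __init__.py
```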
```python
import torch
import numpy as np

from model.yolov5.utils.augmentations import letterbox
from model.yolov5.models.experimental import attempt_load
from model.yolov5.utils.general import check_img_size, non_max_suppression, scale_coords, set_logging, \
    xyxy2xywh
from model.yolov5.utils.torch_utils import select_device, time_sync
from model.yolov5.utils.plots import Annotator, colors


class Yolov5:
    device = ''
    weights = 'model/yolov5/screw.pt'  # model.pt path(s)
    imgsz = 640  # inference size (pixels)
    save_img = True

    def __init__(self):
        set_logging()
        self.device = select_device(self.device)
        # Load model
        w = str(self.weights[0] if isinstance(self.weights, list) else self.weights)
        self.stride, self.names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
        self.model = torch.jit.load(w) if 'torchscript' in w else attempt_load(self.weights, map_location=self.device)
        self.stride = int(self.model.stride.max())  # model stride
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names  # get class names
        self.imgsz = check_img_size([self.imgsz] * 2, s=self.stride)  # expand int to [h, w], check image size
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, *self.imgsz).to(self.device).type_as(next(self.model.parameters())))  # run once (warm-up)

    @torch.no_grad()
    def run(self, im0s,  # BGR HWC image, e.g. as read by cv2.imread
            imgsz=640,  # inference size (pixels)
            conf_thres=0.25,  # confidence threshold
            iou_thres=0.45,  # NMS IOU threshold
            max_det=1000,  # maximum detections per image
            view_img=False,  # show results
            save_txt=False,  # build *.txt label lines
            save_conf=False,  # include confidences in save_txt labels
            save_crop=False,  # save cropped prediction boxes
            classes=None,  # filter by class: --class 0, or --class 0 2 3
            agnostic_nms=False,  # class-agnostic NMS
            augment=False,  # augmented inference
            line_thickness=3,  # bounding box thickness (pixels)
            hide_labels=False,  # hide labels
            hide_conf=False,  # hide confidences
            ):
        # Load image
        assert im0s is not None, 'Image Not Available'
        img = letterbox(im0s, imgsz, stride=self.stride)[0]
        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        dt = [0.0, 0.0, 0.0]
        t1 = time_sync()
        img = torch.from_numpy(img).to(self.device)
        img = img.float()  # uint8 to fp32
        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # expand for batch dim
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        visualize = False
        pred = self.model(img, augment=augment, visualize=visualize)[0]
        t3 = time_sync()
        dt[1] += t3 - t2

        # NMS
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
        dt[2] += time_sync() - t3

        # Process predictions
        result = []
        result_lb = 0
        det = pred[0]  # per image
        s = ''
        im0 = im0s.copy()
        s += '%gx%g ' % img.shape[2:]  # print string
        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
        imc = im0.copy() if save_crop else im0  # for save_crop (unused in this simplified version)
        annotator = Annotator(im0, line_width=line_thickness, example=str(self.names))
        if len(det):
            print('det', det)
            result_lb = 1
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            # Print results
            for c in det[:, -1].unique():
                n = (det[:, -1] == c).sum()  # detections per class
                s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
            # Write results
            for *xyxy, conf, cls in reversed(det):
                if save_txt:  # build label line (no longer written to file here)
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                if self.save_img or save_crop or view_img:  # Add bbox to image
                    c = int(cls)  # integer class
                    label = None if hide_labels else (self.names[c] if hide_conf else f'{self.names[c]} {conf:.2f}')
                    annotator.box_label(xyxy, label, color=colors(c, True))
        im_result = annotator.result()
        result.append(im_result)  # annotated image
        result.append(result_lb)  # 1 if anything was detected, else 0

        # Print time (inference-only)
        print(f'{s}Done. ({t3 - t2:.3f}s)')
        return result  # [image, label]
```
A few points worth noting:
Here device is not specified manually; select_device picks one automatically.
The model file has to be a .pt. That is specific to my setup: the original code supports a whole range of formats, and I deleted all of that to simplify.
The input image is BGR, i.e. the default format OpenCV reads images in, because the conversion code further down assumes it.
imgsz takes a single int, which is later converted into the [h, w] list that YOLO expects (I spent a long time debugging right here).
I wrapped everything in a class: instantiating it loads the model and runs it once as a warm-up, which should reduce the latency of the first real inference (my guess; the original author does the same, running once before actual inference).
Whenever you need a detection, just call run. The return value is a list of the annotated image and the label, though a tuple would work just as well; see the usage sketch below.
The code still has some redundancy that I could not be bothered to comb through again.
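For completeness, a minimal usage sketch. The image and output filenames here are made up for illustration, and the import path depends on which file you put the class in:

```python
import cv2
from model.yolov5.yolo_infer import Yolov5  # hypothetical module name; adjust to your file

detector = Yolov5()  # loads screw.pt and runs the warm-up pass

im = cv2.imread('test.jpg')  # BGR HWC, exactly what run() expects
annotated, found = detector.run(im, conf_thres=0.25)

if found:  # 1 if anything was detected, else 0
    cv2.imwrite('result.jpg', annotated)  # annotated image is still BGR, so cv2.imwrite works directly
```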