1. Put the model file and the class-label file in the same folder.
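
For example, matching the paths that appear in the code later in this post (the Windows model directory and the test image location are taken from there), the files could be laid out as:

    D:\work\git\model\yolov5s\yolov5s.torchscript.pt    (the exported TorchScript model)
    D:\work\git\model\yolov5s\coco.names                (the class labels, one name per line)
    src/main/resources/kana1.jpg                        (the test image read by predict())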

2. Prepare the pom.xml.

<properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <djl.version>0.15.0-SNAPSHOT</djl.version>
        <exec.mainClass>ai.djl.examples.inference.ObjectDetection</exec.mainClass>
    </properties>

    <repositories>
        <repository>
            <id>djl.ai</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        </repository>
    </repositories>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.12.1</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <scope>runtime</scope>
            <version>1.9.1</version>
        </dependency>
        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>6.8.1</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
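
The pom above assumes a regular Spring Boot web application. For completeness, a minimal entry point could look like the sketch below; the class name is only an illustration, the real one lives in the gitee repository linked at the end of this post.

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

// Minimal Spring Boot launcher so that the @PostMapping endpoint shown later can be served.
@SpringBootApplication
public class DjlBootApplication {
    public static void main(String[] args) {
        SpringApplication.run(DjlBootApplication.class, args);
    }
}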

3. Rewrite the official YoloV5Translator.java.

Only this part of the translator needs to be changed. Alternatively, you can adjust the DetectedObjects after they are returned (which is what the DetectedObjectUtil code below does); otherwise the bounding boxes cannot be drawn on the output image.
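
4. Write the detection endpoint (controller).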


    @PostMapping("/ObjectDetectionHbj")
    public void objectDetection() {
        DetectedObjects detection = null;
        try {
            detection = predict();
        } catch (IOException | ModelException | TranslateException e) {
            e.printStackTrace();
        }
        log.info("{}", detection);
    }

    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        Path imageFile = Paths.get("src/main/resources/kana1.jpg");
        Image img = ImageFactory.getInstance().fromFile(imageFile);
        Map<String, Object> arguments = new ConcurrentHashMap<>();
        arguments.put("width", 640);//图片以640宽度进行操作
        arguments.put("height", 640);//图片以640高度进行操作
        arguments.put("resize", true);//调整图片大小
        arguments.put("rescale", true);//图片值编程0-1之间
        //arguments.put("normalize", true);

    /*    arguments.put("toTensor", false);//转换成张量
        arguments.put("range", "0,1");//范围
        arguments.put("normalize", "false");//正态化*/
        //arguments.put("threshold", 0.2);//阈值小于0.2不显示
        //arguments.put("nmsThreshold", 0.5);

        //获取模型分类
        Translator<Image, DetectedObjects> translator = YoloV5Translator.builder(arguments).optSynsetArtifactName("coco.names").build();
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.INSTANCE_SEGMENTATION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optDevice(Device.cpu())
                        .optModelPath(Paths.get("D:\\work\\git\\model\\yolov5s\\"))
                        .optModelName("yolov5s.torchscript.pt") // the model file name
                        .optTranslator(translator)
                        .optProgress(new ProgressBar())
                        .optEngine("PyTorch")
                        .build();
        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
            try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
                DetectedObjects detection = predictor.predict(img);
                saveBoundingBoxImage(img, detection);
                return detection;
            }
        }
    }

    /**
     * @Author bjiang
     * @Description Draw the detection results on the image and write it to build/output
     * @Date 10:08 2021/12/31
     * @Version 1.0
     * @Param [img, detection]
     * @return void
     */
    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);
        DetectedObjects detectionNew= DetectedObjectUtil.me().getDetectedObjects(detection);
        img.drawBoundingBoxes(detectionNew);
        Path imagePath = outputDir.resolve("instances.png");
        img.save(Files.newOutputStream(imagePath), "png");
        System.out.println("Segmentation result image has been saved in"+detectionNew);
    }
    // The members below belong to the DetectedObjectUtil helper class (a simple singleton).
    private static DetectedObjectUtil instance;
    public static DetectedObjectUtil me() {
        if (instance == null) {
            instance = new DetectedObjectUtil();
        }
        return instance;
    }

    /**
     * @Author bjiang
     * @Description Rebuild the detection result, appending the probability to each class name
     * @Date 10:06 2021/12/31
     * @Version 1.0
     * @Param [detection]
     * @return ai.djl.modality.cv.output.DetectedObjects
     */
    public DetectedObjects getDetectedObjects(DetectedObjects detection) {
        List<String> classNames = new ArrayList<>();
        List<Double> probabilities = new ArrayList<>();
        List<BoundingBox> boundingBoxes = new ArrayList<>();
        for (DetectedObjects.DetectedObject obj : detection.<DetectedObjects.DetectedObject>items()) {
            BoundingBox bbox = obj.getBoundingBox();
            Rectangle rectangle = bbox.getBounds();
            // append the probability to the class name so it shows up in the drawn label
            classNames.add(obj.getClassName() + " " + obj.getProbability());
            probabilities.add(obj.getProbability());
            Rectangle rectangleNew = new Rectangle(rectangle.getX(), rectangle.getY(),
                    rectangle.getWidth(), rectangle.getHeight());
            boundingBoxes.add(rectangleNew);
        }
        return new DetectedObjects(classNames, probabilities, boundingBoxes);
    }

5. Call the ObjectDetectionHbj endpoint; the console prints:

[
	class: "person", probability: 0.71513, bounds: [x=0.175, y=0.158, width=0.775, height=0.826]
]
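
The endpoint can be called with a plain POST request, for example (port 8080 is an assumption, the Spring Boot default):

    curl -X POST http://localhost:8080/ObjectDetectionHbj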

Check the build/output directory for the generated instances.png.

Source code: https://gitee.com/bjiangAnhui/djl-boot.git
