1 DJL 是什么

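Deep Java Library(DJL)是一个高性能的开源深度学习框架,专门为 Java 开发者提供训练和推理深度学习模型的能力。DJL 的主要特点包括:与底层引擎无关(可在 PyTorch、TensorFlow、MXNet、ONNX Runtime 等后端之间切换)、提供纯 Java 的 API,以及可通过 Model Zoo 直接加载预训练模型。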

2 如何使用

2.1 添加依赖

<dependencies>
    <!-- DJL 核心 API -->
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>0.26.0</version>
    </dependency>
    <!-- TensorFlow 引擎 -->
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-engine</artifactId>
      <version>0.26.0</version>
    </dependency>
    <!-- TensorFlow 原生库:此处 classifier 固定为 win-x86_64,仅适用于 Windows x86_64;其他平台请修改 classifier,或删除此依赖,由 DJL 在运行时自动下载匹配的原生库 -->
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-native-cpu</artifactId>
      <classifier>win-x86_64</classifier>
      <version>2.10.1</version>
      <scope>runtime</scope>
    </dependency>
    <!-- 包含预训练模型的工具库 -->
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>model-zoo</artifactId>
      <version>0.26.0</version>
      <scope>compile</scope>
    </dependency>
</dependencies>

2.2 加载预训练模型并识别图片

package com.pdtech.boot;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Paths;

public class DjlHelloAi {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // 1. 定义一张需要识别的本地图片路径(一张手写数字 7 的图片)
        String imagePath = "src/main/resources/digit_seven.png";
        Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));

        // 2. 定义 Translator:负责将 Java Image 转换为模型需要的张量(NDArray),并将输出转回文字
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .addTransform(new ToTensor()) // 像素归一化到 0-1
                .addTransform(new Normalize(new float[]{0.1307f}, new float[]{0.3081f})) // 标准化
                .build();

        // 3. 设置模型筛选条件(从云端 Model Zoo 自动加载 TensorFlow 的 MNIST 模型)
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optArtifactId("mnist") // 模型名称
                .optEngine("TensorFlow") // 指定引擎
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        // 4. 加载模型并执行推理
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            
            // 获取结果
            Classifications result = predictor.predict(img);
            System.out.println("预测结果是: " + result.toString());
        }
    }
}