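Deep Java Library(DJL)是一个高性能的开源深度学习框架,专门为 Java 开发者提供深度学习功能。DJL 的主要特点包括:统一的 Java API(底层引擎可选,例如本文使用的 TensorFlow),以及提供预训练模型的 Model Zoo。下面先给出所需的 Maven 依赖,再给出一个识别手写数字图片的完整示例: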
<dependencies>
<!-- Maven dependencies used by the DjlHelloAi example: DJL API, TensorFlow engine, native library, model zoo -->
<!-- DJL core API -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.26.0</version>
</dependency>
<!-- TensorFlow engine -->
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-engine</artifactId>
<version>0.26.0</version>
</dependency>
<!-- TensorFlow CPU native library, runtime scope only.
     The classifier below pins this artifact to win-x86_64.
     NOTE(review): the original comment described this as automatically matching the
     operating system, but the classifier is hard-coded; confirm the intended target
     platform before building on anything other than 64-bit Windows. -->
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<version>2.10.1</version>
<scope>runtime</scope>
</dependency>
<!-- Model zoo library containing pre-trained models -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.26.0</version>
<scope>compile</scope>
</dependency>
</dependencies>
package com.pdtech.boot;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
public class DjlHelloAi {
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// 1. 定义一张需要识别的本地图片路径(一张手写数字 7 的图片)
String imagePath = "src/main/resources/digit_seven.png";
Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
// 2. 定义 Translator:负责将 Java Image 转换为模型需要的张量(NDArray),并将输出转回文字
ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
.addTransform(new ToTensor()) // 像素归一化到 0-1
.addTransform(new Normalize(new float[]{0.1307f}, new float[]{0.3081f})) // 标准化
.build();
// 3. 设置模型筛选条件(从云端 Model Zoo 自动加载 TensorFlow 的 MNIST 模型)
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optArtifactId("mnist") // 模型名称
.optEngine("TensorFlow") // 指定引擎
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();
// 4. 加载模型并执行推理
try (ZooModel<Image, Classifications> model = criteria.loadModel();
Predictor<Image, Classifications> predictor = model.newPredictor()) {
// 获取结果
Classifications result = predictor.predict(img);
System.out.println("预测结果是: " + result.toString());
}
}
}