Java调用Pytorch实现以图搜图

设计技术栈:
1、ElasticSearch环境;
2、Python运行环境(如果事先没有pytorch模型时,可以用python脚本创建模型);

1、运行效果

2、创建模型(有则可以跳过)

vi script.py

import torch
import torch.nn as nn
import torchvision.models as modelsclass ImageFeatureExtractor(nn.Module):def __init__(self):super(ImageFeatureExtractor, self).__init__()self.resnet = models.resnet50(pretrained=True)#最终输出维度1024的向量,下文elastic search要设置dims为1024self.resnet.fc = nn.Linear(2048, 1024)def forward(self, x):x = self.resnet(x)return xif __name__ == '__main__':model = ImageFeatureExtractor()model.eval()#根据模型随便创建一个输入input = torch.rand([1, 3, 224, 224])output = model(input)#以这种方式保存script = torch.jit.trace(model, input)script.save("model.pt")

2、java项目pom.xml

<dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><scope>provided</scope></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.19.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-native-cpu</artifactId><version>1.10.0</version><scope>runtime</scope></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-jni</artifactId><version>1.10.0-0.19.0</version></dependency><dependency><groupId>org.elasticsearch.client</groupId><artifactId>elasticsearch-rest-high-level-client</artifactId></dependency></dependencies>

3、ES创建文档

PUT /isi
{"mappings": {"properties": {"vector": {"type": "dense_vector","dims": 1024},"url" : {"type" : "keyword"},"user_id": {"type": "keyword"}}}
}

4、编写java代码调用模型

ORCUtil.java

package com.topprismcloud.rtm;import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;public class ORCUtil {private static final String INDEX = "isi";private static final int IMAGE_SIZE = 224;private static Model model; // 模型private static Predictor<Image, float[]> predictor; // predictor.predict(input)相当于python中model(input)static {try {model = Model.newInstance("model");// 这里的model.pt是上面代码展示的那种方式保存的model.load(ORCUtil.class.getClassLoader().getResourceAsStream("model.pt"));Transform resize = new Resize(IMAGE_SIZE);Transform toTensor = new ToTensor();Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },new float[] { 0.229f, 0.224f, 0.225f });// Translator处理输入Image转为tensor、输出转为float[]Translator<Image, float[]> translator = new Translator<Image, float[]>() {@Overridepublic NDList processInput(TranslatorContext ctx, Image input) throws Exception {NDManager ndManager = ctx.getNDManager();System.out.println("input: " + input.getWidth() + ", " + input.getHeight());NDArray transform = normalize.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));System.out.println(transform.getShape());NDList list = new NDList();list.add(transform);return list;}@Overridepublic float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {return ndList.get(0).toFloatArray();}};predictor = new Predictor<>(model, translator, Device.cpu(), true);} catch (Exception e) {e.printStackTrace();}}public static void upload() throws Exception {HttpHost host=new HttpHost("14.20.30.16", 9200, HttpHost.DEFAULT_SCHEME_NAME);RestClientBuilder builder=RestClient.builder(host);CredentialsProvider credentialsProvider = new BasicCredentialsProvider();credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("elastic", "123456"));builder.setHttpClientConfigCallback(f -> f.setDefaultCredentialsProvider(credentialsProvider));RestHighLevelClient client = new RestHighLevelClient( builder);// 批量上传请求BulkRequest bulkRequest = new BulkRequest(INDEX);File file = new File("D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");for (File listFile : file.listFiles()) {
//          float[] vector = predictor.predict(ImageFactory.getInstance()
//                  .fromInputStream(Test.class.getClassLoader().getResourceAsStream("new/" + listFile.getName())));float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(new FileInputStream(listFile)));// 构建文档Map<String, Object> jsonMap = new HashMap<>();jsonMap.put("url", "/resource/"+listFile.getName());jsonMap.put("vector", vector);jsonMap.put("user_id", "user123");IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);bulkRequest.add(request);}client.bulk(bulkRequest, RequestOptions.DEFAULT);client.close();}// 接收待搜索图片的inputstream,搜索与其相似的图片public static List<SearchResult> search(InputStream input) throws Throwable {float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));System.out.println(Arrays.toString(vector));// 展示k个结果int k = 100;// 连接Elasticsearch服务器RestHighLevelClient client = new RestHighLevelClient(RestClient.builder(new HttpHost("14.20.30.16", 9200, "http")));SearchRequest searchRequest = new SearchRequest(INDEX);Script script = new Script(ScriptType.INLINE, "painless", "cosineSimilarity(params.queryVector, doc['vector'])",Collections.singletonMap("queryVector", vector));FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders.functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();searchSourceBuilder.query(functionScoreQueryBuilder).fetchSource(null, "vector") // 不返回vector字段,太多了没用还耗时.size(k);searchRequest.source(searchSourceBuilder);SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);SearchHits hits = searchResponse.getHits();List<SearchResult> list = new ArrayList<>();for (SearchHit hit : hits) {// 处理搜索结果System.out.println(hit.toString());SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());list.add(result);}client.close();return list;}public static void main(String[] args) throws Throwable {ORCUtil.upload();System.out.println("hao");}
}

SearchController.java

package com.topprismcloud.rtm;import java.util.List;import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;@RestController
@CrossOrigin
public class SearchController {@PostMapping("search")public ResponseEntity search(MultipartFile file) {try {List<SearchResult> list = ORCUtil.search(file.getInputStream());return ResponseEntity.ok(list);} catch (Throwable e) {return ResponseEntity.status(400).body(null);}}
}

SearchResult.java

package com.topprismcloud.rtm;import lombok.AllArgsConstructor;
import lombok.Data;@Data
@AllArgsConstructor
public class SearchResult {private String url;private Float score;
}

5、前端

index.html

<!DOCTYPE html>
<html lang="zh"><head><meta charset="UTF-8"><title>以图搜图</title><style>body {background: url("/img/bg.jpg");background-attachment: fixed;background-size: 100% 100%;}body>div {width: 1000px;margin: 50px auto;padding: 10px 20px;border: 1px solid lightgray;border-radius: 20px;box-sizing: border-box;background: rgba(255, 255, 255, 0.7);}.upload {display: inline-block;width: 300px;height: 280px;border: 1px dashed lightcoral;vertical-align: top;}.upload .cover {width: 200px;height: 200px;margin: 10px 50px;border: 1px solid black;box-sizing: border-box;text-align: center;line-height: 200px;position: relative;}.upload img {width: 198px;height: 198px;position: absolute;left: 0;top: 0;}.upload input {margin-left: 50px;}.upload button {width: 80px;height: 30px;margin-left: 110px;}.result-block {display: inline-block;margin-left: 40px;border: 1px solid lightgray;border-radius: 10px;min-height: 500px;width: 600px;}.result-block h1 {text-align: center;margin-top: 100px;}.result {padding: 10px;cursor: pointer;display: inline-block;}.result:hover {background: rgb(240, 240, 240);}.result p {width: 110px;overflow: hidden;white-space: nowrap;text-overflow: ellipsis;}.result img {width: 160px;height: 160px;}.result .prob {color: rgb(37, 147, 60)}</style><script src="js/jquery-3.6.0.js"></script>
</head><body><div><div class="upload"><div class="cover">请选择图片<img id="image" src="" /></div><input id="file" type="file"></div><div class="result-block"><h1>请选择图片</h1></div></div><ul id="box"></ul><script>var file = $('#file')file.change(function () {let f = this.files[0]let index = f.name.lastIndexOf('.')let fileText = f.name.substring(index, f.name.length)let ext = fileText.toLowerCase() //文件类型console.log(ext)if (ext != '.png' && ext != '.jpg' && ext != '.jpeg') {alert('系统仅支持 JPG、PNG、JPEG 格式的图片,请您调整格式后重新上传')return}$('.result-block').empty().append($('<h1>正在识别中...</h1>'))$("#image").attr("src", getObjectURL(f));let formData = new FormData()formData.append('file', f)$.ajax({url: 'http://10.1.2.240:8081/search',method: 'post',data: formData,processData: false,contentType: false,success: res => {console.log('shibie', res)$('.result-block').empty()for (let item of res) {console.log(item)let html = `<div class="result"><img src="${item.url}"/><div style="display: inline-block;vertical-align: top"><p class="prob">得分:${item.score.toFixed(4)}</p></div></div>`$('.result-block').append($(html))}}})});$('#button').click(function (e) {var file = $('#file')[0].files[0] //单个console.log(file)})function getObjectURL(file) {var url = null;if (window.createObjcectURL != undefined) {url = window.createOjcectURL(file);} else if (window.URL != undefined) {url = window.URL.createObjectURL(file);} else if (window.webkitURL != undefined) {url = window.webkitURL.createObjectURL(file);}return url;}function detect() {}</script>
</body></html>

6、打包后的源代码

以图搜图Java+html源代码

相关参考文章:Java调用Pytorch模型进行图像识别

Java调用Pytorch实现以图搜图(附源码)相关推荐

  1. Java+ElasticSearch+Pytorch实现以图搜图

    以图搜图,涉及两大功能:1.提取图像特征向量.2.相似向量检索. 第一个功能我通过编写pytorch模型并在java端借助djl调用实现,第二个功能通过elasticsearch7.6.2的dense ...

  2. Java毕设项目在线答题系统计算机(附源码+系统+数据库+LW)

    Java毕设项目在线答题系统计算机(附源码+系统+数据库+LW) 项目运行 环境配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Ecli ...

  3. JAVA计算机毕业设计校园订餐系统(附源码、数据库)

    JAVA计算机毕业设计校园订餐系统(附源码.数据库) 目运行 环境项配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(In ...

  4. JAVA计算机毕业设计网课系统(附源码、数据库)

    JAVA计算机毕业设计网课系统(附源码.数据库) 目运行 环境项配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(Inte ...

  5. JAVA计算机毕业设计漫画网站系统(附源码、数据库)

    JAVA计算机毕业设计漫画网站系统(附源码.数据库) 目运行 环境项配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(In ...

  6. JAVA计算机毕业设计快递代收系统(附源码、数据库)

    JAVA计算机毕业设计快递代收系统(附源码.数据库) 项目运行 环境配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(In ...

  7. java计算机毕业设计网络游戏后台管理系统(附源码、数据库)

    java计算机毕业设计网络游戏后台管理系统(附源码.数据库) 项目运行 环境配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclisp ...

  8. JAVA计算机毕业设计物料追溯系统(附源码、数据库)

    JAVA计算机毕业设计物料追溯系统(附源码.数据库) 目运行 环境项配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(In ...

  9. java计算机毕业设计购物网站设计(附源码、数据库)

    java计算机毕业设计购物网站设计(附源码.数据库) 项目运行 环境配置: Jdk1.8 + Tomcat8.5 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(In ...

最新文章

  1. 删了手机里的一个html文件,手机太卡,哪些内容可以毫不犹豫的删除?
  2. Python DAG—归简法—拓扑排序
  3. rds 数据导入mysql_将数据导入到 Amazon RDS 数据库实例
  4. 计算机的代表性产品,电脑展回顾 十款最具代表性存储产品
  5. kubernetes 数据_为什么数据科学家喜欢Kubernetes
  6. 0. 跟踪标记 (Trace Flag) 简介
  7. [转载] python的next()函数
  8. java ssh 学习_初学Java ssh之Spring 第三篇
  9. 办公搜索利器UTOOLS-基于EVERYTHING的文件快速搜索软件
  10. Visio2013激活/破解
  11. Linux之Shell编程详解
  12. sipP测试,UAS怎么主动发BYE消息
  13. 前端手册-CSS3 属性手册
  14. word替换向下箭头符号
  15. 平方米用计算机怎么计算公式,公式的换算和公式计算器
  16. 最小生成树-Prim + Kruskal算法
  17. Word/excel/df文档转图片返回前端
  18. centos7安装搜狗拼音
  19. 玉米社:百度SEM竞价推广策略有哪些?
  20. STM32F103调试笔记(1)——microusb接入电脑后显示未知USB设备(代码43)

热门文章

  1. 中望CAD的引线标注格式怎么改_国产操作系统生态有新进展!中望携手统信推出UOS版本CAD...
  2. 计算0到100之间的奇数和偶数的和
  3. python期中考试及答案_PYTHON期中考试试卷
  4. word文档计算机二级,计算机二级考试word
  5. 北京金桃科技有限公司(面试题)
  6. 内部转发和重定向的区别
  7. 制作U盘启动后容量减少的解决办法
  8. 使用FluentMybatis实现mybatis动态sql拼装和fluent api语法
  9. 内网映射代理方案(内网穿透)
  10. squid代理服务之透明代理的配置方法