pytorch 移动端部署之helloworld的使用


Posted in Python onOctober 30, 2020

开始

安装Androidstudio 4.1

克隆此项目

git clone https://github.com/pytorch/android-demo-app.git

使用androidstudio 打开 android-demo-app 中的HelloWordApp

打开之后androidstudio 会自动创建依赖 只需要等待即可

这个代码已经是官方写好的故而

开一下官方教程中的代码都在什么位置

这句

repositories {
  jcenter()
}

dependencies {
  implementation 'org.pytorch:pytorch_android:1.4.0'
  implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

位置

HelloWorldApp\app\build.gradle

里面的全部代码

apply plugin: 'com.android.application'
repositories {
  jcenter()
}

android {
  compileSdkVersion 28
  buildToolsVersion "29.0.2"
  defaultConfig {
    applicationId "org.pytorch.helloworld"
    minSdkVersion 21
    targetSdkVersion 28
    versionCode 1
    versionName "1.0"
  }
  buildTypes {
    release {
      minifyEnabled false
    }
  }
}

dependencies {
  implementation 'androidx.appcompat:appcompat:1.1.0'
  implementation 'org.pytorch:pytorch_android:1.4.0'
  implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

这句

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
  TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
  Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
 if (scores[i] > maxScore) {
  maxScore = scores[i];
  maxScoreIdx = i;
 }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

都在这里

HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java

全部代码

package org.pytorch.helloworld;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {

 @Override
 protected void onCreate(Bundle savedInstanceState) {
  super.onCreate(savedInstanceState);
  setContentView(R.layout.activity_main);

  Bitmap bitmap = null;
  Module module = null;
  try {
   // creating bitmap from packaged into app android asset 'image.jpg',
   // app/src/main/assets/image.jpg
   bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
   // loading serialized torchscript module from packaged into app android asset model.pt,
   // app/src/model/assets/model.pt
   module = Module.load(assetFilePath(this, "model.pt"));
  } catch (IOException e) {
   Log.e("PytorchHelloWorld", "Error reading assets", e);
   finish();
  }

  // showing image on UI
  ImageView imageView = findViewById(R.id.image);
  imageView.setImageBitmap(bitmap);

  // preparing input tensor
  final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

  // running the model
  final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

  // getting tensor content as java array of floats
  final float[] scores = outputTensor.getDataAsFloatArray();

  // searching for the index with maximum score
  float maxScore = -Float.MAX_VALUE;
  int maxScoreIdx = -1;
  for (int i = 0; i < scores.length; i++) {
   if (scores[i] > maxScore) {
    maxScore = scores[i];
    maxScoreIdx = i;
   }
  }

  String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

  // showing className on UI
  TextView textView = findViewById(R.id.text);
  textView.setText(className);
 }

 /**
  * Copies specified asset to the file in /files app directory and returns this file absolute path.
  *
  * @return absolute file path
  */
 public static String assetFilePath(Context context, String assetName) throws IOException {
  File file = new File(context.getFilesDir(), assetName);
  if (file.exists() && file.length() > 0) {
   return file.getAbsolutePath();
  }

  try (InputStream is = context.getAssets().open(assetName)) {
   try (OutputStream os = new FileOutputStream(file)) {
    byte[] buffer = new byte[4 * 1024];
    int read;
    while ((read = is.read(buffer)) != -1) {
     os.write(buffer, 0, read);
    }
    os.flush();
   }
   return file.getAbsolutePath();
  }
 }
}

在Build 中选择Build Bundile APK 的 Build APK 就可以了

生成的apk 在

HelloWorldApp\app\build\outputs\apk\debug

中 这个是可以直接安装的

安装后是一个固定的照片 就是检测了一个固定的照片

这是一个例子如果你只是想测试自己的模型调用能不能成功这个项目改改模型和模型加载即可

这个项目模型是一个resnet18 接着我们将其替换为resnet50

模型转换代码如下

import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()

model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
resnet.save('model.pt')
if __name__ == '__main__':
  pass

将这个保存的模型 覆盖掉下面路径中的模型
 (在覆盖之前最好备份一个原来的模型,这里我们选择修改原来模型的名字为model_1.pt)

HelloWorldApp\app\src\main\assets\model.pt

成功覆盖后再一次执行打包操作(在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug)
而后打开文件发现一个123M的apk 之前的apk是73M

安装 并且测试

完美打开也就是说一切resnet 系列的 都可以通过这个 项目进行演化出来

到此这篇关于pytorch 移动端部署之helloworld的使用的文章就介绍到这了,更多相关pytorch 移动端部署helloworld内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
简单介绍Python中用于求最小值的min()方法
May 15 Python
python中list常用操作实例详解
Jun 03 Python
基于Django的python验证码(实例讲解)
Oct 23 Python
利用scrapy将爬到的数据保存到mysql(防止重复)
Mar 31 Python
基于Django框架利用Ajax实现点赞功能实例代码
Aug 19 Python
Python面向对象程序设计之私有属性及私有方法示例
Apr 08 Python
Python3.5面向对象编程图文与实例详解
Apr 24 Python
详解python调用cmd命令三种方法
Jul 08 Python
Django使用Channels实现WebSocket的方法
Jul 28 Python
Python学习笔记之列表推导式实例分析
Aug 13 Python
Python实现密码薄文件读写操作
Dec 16 Python
python中slice参数过长的处理方法及实例
Dec 15 Python
把Anaconda中的环境导入到Pycharm里面的方法步骤
Oct 30 #Python
Python模拟登录和登录跳转的参考示例
Oct 30 #Python
python中watchdog文件监控与检测上传功能
Oct 30 #Python
GitHub上值得推荐的8个python 项目
Oct 30 #Python
python读取excel数据绘制简单曲线图的完整步骤记录
Oct 30 #Python
用python写PDF转换器的实现
Oct 29 #Python
python查询MySQL将数据写入Excel
Oct 29 #Python
You might like
xajax写的留言本
2006/11/25 PHP
php 数组的合并、拆分、区别取值函数集
2010/02/15 PHP
php中实现简单的ACL 完结篇
2011/09/07 PHP
php数组声明、遍历、数组全局变量使用小结
2013/06/05 PHP
php文件包含目录配置open_basedir的使用与性能详解
2017/04/03 PHP
php+mysql开发的最简单在线题库(在线做题系统)完整案例
2019/03/30 PHP
javascript 随机展示头像实现代码
2011/12/06 Javascript
jquery ui对话框实例代码
2013/05/10 Javascript
javascript实现日期按月份加减
2015/05/15 Javascript
node.js中fs.stat与fs.fstat的区别详解
2017/06/01 Javascript
微信小程序支付前端源码
2018/08/29 Javascript
Vue中使用方法、计算属性或观察者的方法实例详解
2018/10/31 Javascript
js前端面试之同步与异步问题详解
2019/04/03 Javascript
微信小程序npm引入vant-weapp的踩坑记录
2019/08/01 Javascript
Vue中登录验证成功后保存token,并每次请求携带并验证token操作
2020/09/08 Javascript
Java中重定向输出流实现用文件记录程序日志
2015/06/12 Python
Windows上配置Emacs来开发Python及用Python扩展Emacs
2015/11/20 Python
Python简单实现enum功能的方法
2016/04/25 Python
Ubuntu下使用Python实现游戏制作中的切分图片功能
2018/03/30 Python
使用python脚本实现查询火车票工具
2018/07/19 Python
Python使用gRPC传输协议教程
2018/10/16 Python
Django框架验证码用法实例分析
2019/05/10 Python
python list转置和前后反转的例子
2019/08/26 Python
python tqdm 实现滚动条不上下滚动代码(保持一行内滚动)
2020/02/19 Python
matplotlib bar()实现百分比堆积柱状图
2021/02/24 Python
美国著名的户外用品品牌:L.L.Bean
2018/01/05 全球购物
英国版MAC彩妆品牌:Illamasqua
2018/04/18 全球购物
澳大利亚在线性感内衣商店:Fantasy Lingerie
2021/02/07 全球购物
室内设计专业个人的自我评价
2013/12/18 职场文书
企业车辆管理制度
2014/01/24 职场文书
卫生安全检查制度
2014/02/04 职场文书
扩大国家免疫规划实施方案
2014/03/21 职场文书
技术支持岗位职责
2015/02/13 职场文书
2015年信访维稳工作总结
2015/04/07 职场文书
简述python四种分词工具,盘点哪个更好用?
2021/04/13 Python
postgresql如何找到表中重复数据的行并删除
2023/05/08 MySQL