pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python创建文件和追加文件内容实例
Oct 21 Python
详解Python中的相对导入和绝对导入
Jan 06 Python
Python与人工神经网络:使用神经网络识别手写图像介绍
Dec 19 Python
python 用所有标点符号分隔句子的示例
Jul 15 Python
用Anaconda安装本地python包的方法及路径问题(图文)
Jul 16 Python
Python generator生成器和yield表达式详解
Aug 08 Python
解决Python3用PIL的ImageFont输出中文乱码的问题
Aug 22 Python
Python多继承以及MRO顺序的使用
Nov 11 Python
如何使用python进行pdf文件分割
Nov 11 Python
python绘制规则网络图形实例
Dec 09 Python
python3中使用__slots__限定实例属性操作分析
Feb 14 Python
Python面向对象多态实现原理及代码实例
Sep 16 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php md5下16位和32位的实现代码
2008/04/09 PHP
PHP遍历某个目录下的所有文件和子文件夹的实现代码
2013/06/28 PHP
php运行时动态创建函数的方法
2015/03/16 PHP
PHP+jQuery实现即点即改功能示例
2019/02/21 PHP
Thinkphp框架+Layui实现图片/文件上传功能分析
2020/02/07 PHP
改善你的jQuery的25个步骤 千倍级效率提升
2010/02/11 Javascript
jquery操作cookie插件分享
2014/01/14 Javascript
jQuery实现当按下回车键时绑定点击事件
2014/01/28 Javascript
node.js中的console.warn方法使用说明
2014/12/09 Javascript
使用基于Node.js的构建工具Grunt来发布ASP.NET MVC项目
2016/02/15 Javascript
AngularJS constant和value区别详解
2017/02/28 Javascript
利用Plupload.js解决大文件上传问题, 带进度条和背景遮罩层
2017/03/15 Javascript
js canvas实现QQ拨打电话特效
2017/05/10 Javascript
jQuery响应滚动条事件功能示例
2017/10/14 jQuery
vue.js系列中的vue-fontawesome使用
2018/02/10 Javascript
详解Nodejs内存治理
2018/05/13 NodeJs
JS与jQuery判断文本框还剩多少字符可以输入的方法
2018/09/01 jQuery
JS中通过url动态获取图片大小的方法小结(两种方法)
2018/10/31 Javascript
node.js通过Sequelize 连接MySQL的方法
2020/12/28 Javascript
vue实现树状表格效果
2020/12/29 Vue.js
获取python文件扩展名和文件名方法
2018/02/02 Python
numpy linalg模块的具体使用方法
2019/05/26 Python
python如何爬取网站数据并进行数据可视化
2019/07/08 Python
通过PHP与Python代码对比的语法差异详解
2019/07/10 Python
韩国休闲女装品牌网站:ANAIS
2016/08/24 全球购物
Linden Leaves官网:新西兰纯净护肤品
2020/12/20 全球购物
英国领先的高级美容和在线皮肤诊所:Face the Future
2020/06/17 全球购物
创建学习型党组织实施方案
2014/03/29 职场文书
产品质量承诺范本
2014/03/31 职场文书
艺术设计专业毕业生推荐信
2014/07/08 职场文书
初中国旗下的演讲稿
2014/08/28 职场文书
解除劳动合同协议书(样本)
2014/10/02 职场文书
用Python的绘图库(matplotlib)绘制小波能量谱
2021/04/17 Python
教你做个可爱的css滑动导航条
2021/06/15 HTML / CSS
浅谈Python数学建模之线性规划
2021/06/23 Python
Python3.8官网文档之类的基础语法阅读
2021/09/04 Python