pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多进程multiprocessing.Pool类详解
Apr 27 Python
Python爬虫包BeautifulSoup学习实例(五)
Jun 17 Python
python实现飞机大战微信小游戏
Mar 21 Python
python中使用zip函数出现错误的原因
Sep 28 Python
Python GUI编程 文本弹窗的实例
Jun 11 Python
python中metaclass原理与用法详解
Jun 25 Python
python matplotlib拟合直线的实现
Nov 19 Python
基于Tensorflow:CPU性能分析
Feb 10 Python
基于SQLAlchemy实现操作MySQL并执行原生sql语句
Jun 10 Python
如何以Winsows Service方式运行JupyterLab
Aug 30 Python
详解Python 中的 defaultdict 数据类型
Feb 22 Python
解决TensorFlow训练模型及保存数量限制的问题
Mar 03 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
纯php打造的tab选项卡效果代码(不用js)
2010/12/29 PHP
PHP后端银联支付及退款实例代码
2017/06/23 PHP
PHP+Ajax实现上传文件进度条动态显示进度功能
2018/06/04 PHP
线路分流自动智能跳转代码,自动选择最快镜像网站(js)
2011/10/31 Javascript
JavaScript对象创建及继承原理实例解剖
2013/02/28 Javascript
IE下JS读取xml文件示例代码
2013/08/05 Javascript
js实现精确到秒的倒计时效果
2016/05/29 Javascript
JS代码实现根据时间变换页面背景效果
2016/06/16 Javascript
javascript 网页进度条简单实例
2017/02/22 Javascript
jQuery Json数据格式排版高亮插件json-viewer.js使用方法详解
2017/06/12 jQuery
浅谈ES6新增的数组方法和对象
2017/08/08 Javascript
微信小程序三级联动选择器使用方法
2020/05/19 Javascript
vue-cli中的babel配置文件.babelrc实例详解
2018/02/22 Javascript
微信小程序异步API为Promise简化异步编程的操作方法
2018/08/14 Javascript
微信小程序嵌入腾讯视频源过程详解
2019/08/08 Javascript
vue学习之Vue-Router用法实例分析
2020/01/06 Javascript
ES6使用新特性Proxy实现的数据绑定功能实例
2020/05/11 Javascript
[01:06:32]DOTA2上海特级锦标赛D组资格赛#1 EG VS VP第一局
2016/02/28 DOTA
Django的分页器实例(paginator)
2017/12/01 Python
Python中顺序表的实现简单代码分享
2018/01/09 Python
对Tensorflow中的变量初始化函数详解
2018/07/27 Python
python使用PIL模块获取图片像素点的方法
2019/01/08 Python
python3.6实现学生信息管理系统
2019/02/21 Python
python实现抽奖小程序
2020/04/15 Python
Django基础知识 URL路由系统详解
2019/07/18 Python
pycharm远程连接vagrant虚拟机中mariadb数据库
2020/06/05 Python
Python 高效编程技巧分享
2020/09/10 Python
HTML5中使用json对象的实例代码
2018/09/10 HTML / CSS
Omio俄罗斯:一次搜索公共汽车、火车和飞机的机票
2018/11/17 全球购物
linux面试题参考答案(4)
2013/01/28 面试题
大学生毕业求职的自我评价
2013/09/29 职场文书
综治维稳工作承诺书
2014/08/30 职场文书
护士自荐信怎么写
2015/03/06 职场文书
html form表单基础入门案例讲解
2021/07/15 HTML / CSS
MySQL实现配置主从复制项目实践
2022/03/31 MySQL
zabbix配置nginx监控的实现
2022/05/25 Servers