pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python出现&quot;IndentationError: unexpected indent&quot;错误解决办法
Oct 15 Python
用Python进行简单图像识别(验证码)
Jan 19 Python
Python随机函数random()使用方法小结
Apr 29 Python
Python面向对象之反射/自省机制实例分析
Aug 24 Python
python统计中文字符数量的两种方法
Jan 31 Python
TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片
Mar 14 Python
python进程和线程用法知识点总结
May 28 Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 Python
Python list与NumPy array 区分详解
Nov 06 Python
详解python中的异常和文件读写
Jan 03 Python
Python函数中的不定长参数相关知识总结
Jun 24 Python
Elasticsearch 索引操作和增删改查
Apr 19 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php加水印的代码(支持半透明透明打水印,支持png透明背景)
2013/01/17 PHP
laravel通过创建自定义artisan make命令来新建类文件详解
2017/08/17 PHP
Thinkphp极验滑动验证码实现步骤解析
2020/11/24 PHP
基于jquery的修改当前TAB显示标题的代码
2010/12/11 Javascript
jquery中使用ajax获取远程页面信息
2011/11/13 Javascript
Jquery 获取checkbox的checked问题
2011/11/16 Javascript
JQuery球队选择实例
2015/05/18 Javascript
js禁止页面刷新与后退的方法
2015/06/08 Javascript
javascript实现在线客服效果
2015/07/15 Javascript
jqueryMobile使用示例分享
2016/01/12 Javascript
JS常用知识点整理
2017/01/21 Javascript
Three.js入门之hello world以及如何绘制线
2017/09/25 Javascript
react-native中ListView组件点击跳转的方法示例
2017/09/30 Javascript
JavaScript对JSON数组简单排序操作示例
2019/01/31 Javascript
layui实现三级联动效果
2019/07/26 Javascript
Vue实现仿iPhone悬浮球的示例代码
2020/03/13 Javascript
JavaScript DOM常用操作代码汇总
2020/07/03 Javascript
javascript的hashCode函数实现代码小结
2020/08/11 Javascript
深入Python解释器理解Python中的字节码
2015/04/01 Python
Python新手在作用域方面经常容易碰到的问题
2015/04/03 Python
Python Web框架Tornado运行和部署
2020/10/19 Python
python操作字典类型的常用方法(推荐)
2016/05/16 Python
python之PyMongo使用总结
2017/05/26 Python
python opencv实现旋转矩形框裁减功能
2018/07/25 Python
Python魔法方法详解
2019/02/13 Python
git查看、创建、删除、本地、远程分支方法详解
2020/02/18 Python
Python爬虫实现百度翻译功能过程详解
2020/05/29 Python
Python3获取cookie常用三种方案
2020/10/05 Python
局域网定义和特性
2016/01/23 面试题
打架检讨书500字
2014/01/29 职场文书
精彩的英文自荐信
2014/01/30 职场文书
书香校园活动方案
2014/02/28 职场文书
妇女儿童发展规划实施方案
2014/03/16 职场文书
政治学专业毕业生求职信
2014/08/11 职场文书
2014年政府采购工作总结
2014/12/09 职场文书
2016年基层党组织公开承诺书
2016/03/25 职场文书