pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
线程和进程的区别及Python代码实例
Feb 04 Python
python处理按钮消息的实例详解
Jul 11 Python
Python调用ctypes使用C函数printf的方法
Aug 23 Python
Python3 循环语句(for、while、break、range等)
Nov 20 Python
python 计算数组中每个数字出现多少次--“Bucket”桶的思想
Dec 19 Python
Python3 中把txt数据文件读入到矩阵中的方法
Apr 27 Python
Python3中urlencode和urldecode的用法详解
Jul 23 Python
Python 日期的转换及计算的具体使用详解
Jan 16 Python
python GUI框架pyqt5 对图片进行流式布局的方法(瀑布流flowlayout)
Mar 12 Python
Python爬虫之Selenium实现关闭浏览器
Dec 04 Python
pycharm实现猜数游戏
Dec 07 Python
Python3接口性能测试实例代码
Jun 20 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
DOTA2 玩家自创拉野攻略 特色英雄快速成长篇
2020/04/20 DOTA
使用php+xslt在windows平台上
2006/10/09 PHP
PHP判断用户是否已经登录(跳转到不同页面或者执行不同动作)
2016/09/22 PHP
js变量作用域及可访问性的探讨
2006/11/23 Javascript
jQuery 技巧大全(新手入门篇)
2009/05/12 Javascript
jquery分页对象使用示例
2014/04/01 Javascript
js实现div弹出层的方法
2014/11/20 Javascript
高效的jquery数字滚动特效
2015/12/17 Javascript
AngularJS中使用HTML5手机摄像头拍照
2016/02/22 Javascript
Vue 2.0中生命周期与钩子函数的一些理解
2017/05/09 Javascript
vue.js实现条件渲染的实例代码
2017/06/22 Javascript
Three.js利用dat.GUI如何简化试验流程详解
2017/09/26 Javascript
如何理解Vue的v-model指令的使用方法
2018/07/19 Javascript
es6中reduce的基本使用方法
2019/09/10 Javascript
js实现无刷新监听URL的变化示例代码详解
2020/06/03 Javascript
python基础教程之python消息摘要算法使用示例
2014/02/10 Python
简单易懂的python环境安装教程
2017/07/13 Python
Python 限定函数参数的类型及默认值方式
2019/12/24 Python
使用keras实现densenet和Xception的模型融合
2020/05/23 Python
利用Storage Event实现页面间通信的示例代码
2018/07/26 HTML / CSS
三星印度官网:Samsung印度
2019/08/03 全球购物
应届生会计电算化求职信
2013/10/03 职场文书
物流专业毕业生推荐信范文
2013/11/18 职场文书
安全事故检讨书
2014/01/18 职场文书
周年庆典邀请函范文
2014/01/24 职场文书
上班离岗检讨书
2014/01/27 职场文书
幼儿园母亲节活动方案
2014/03/10 职场文书
教师中国梦演讲稿
2014/04/23 职场文书
企业安全生产月活动总结
2014/07/05 职场文书
赵乐秦在党的群众路线教育实践活动总结大会上的讲话稿
2014/10/25 职场文书
小学见习报告
2014/10/31 职场文书
故意伤害人身损害赔偿协议书
2014/11/19 职场文书
领导参观欢迎词
2015/01/26 职场文书
幼儿园中班教师个人总结
2015/02/05 职场文书
2016年大学生就业指导课心得体会
2015/10/09 职场文书
《海上日出》教学反思
2016/02/23 职场文书