pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类继承与子类实例初始化用法分析
Apr 17 Python
python实现机械分词之逆向最大匹配算法代码示例
Dec 13 Python
解决pandas .to_excel不覆盖已有sheet的问题
Dec 10 Python
python添加菜单图文讲解
Jun 04 Python
python cumsum函数的具体使用
Jul 29 Python
python使用 request 发送表单数据操作示例
Sep 25 Python
关于sys.stdout和print的区别详解
Dec 05 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
May 28 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
Python析构函数__del__定义原理解析
Nov 20 Python
Python如何利用pandas读取csv数据并绘图
Jul 07 Python
Python使用pandas导入xlsx格式的excel文件内容操作代码
Dec 24 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php中curl、fsocket、file_get_content三个函数的使用比较
2014/05/09 PHP
PHP二进制与字符串之间的相互转换教程
2016/10/14 PHP
Centos 6.5系统下编译安装PHP 7.0.13的方法
2016/12/19 PHP
php 命名空间(namespace)原理与用法实例小结
2019/11/13 PHP
Chrome中JSON.parse的特殊实现
2011/01/12 Javascript
Function.prototype.call.apply结合用法分析示例
2013/07/03 Javascript
javascript 原型链维护和继承详解
2014/11/26 Javascript
jQuery调用ajax请求的常见方法汇总
2015/03/24 Javascript
JS代码随机生成姓名、手机号、身份证号、银行卡号
2016/04/27 Javascript
iscroll碰到Select无法选择下拉刷新的解决办法
2016/05/21 Javascript
详解react内联样式使用webpack将px转rem
2018/09/13 Javascript
vue-lazyload使用总结(推荐)
2018/11/01 Javascript
小程序获取当前位置加搜索附近热门小区及商区的方法
2019/04/08 Javascript
jquery实现进度条状态展示
2020/03/26 jQuery
[03:12]完美世界DOTA2联赛PWL DAY7集锦
2020/11/06 DOTA
Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器
2014/06/04 Python
Python实现的二维码生成小软件
2014/07/11 Python
Python实现八大排序算法
2016/08/13 Python
Python实现破解12306图片验证码的方法分析
2017/12/29 Python
Django中使用CORS实现跨域请求过程解析
2019/08/05 Python
学python安装的软件总结
2019/10/12 Python
PyCharm如何导入python项目的方法
2020/02/06 Python
Python openpyxl模块实现excel读写操作
2020/06/30 Python
python 利用toapi库自动生成api
2020/10/19 Python
亚历山大·王官网:Alexander Wang
2017/06/23 全球购物
美国第二大连锁书店:Books-A-Million
2017/12/28 全球购物
法国发饰品牌:Alexandre De Paris
2018/12/04 全球购物
这段代码难道不该打印出56吗
2013/02/27 面试题
大学生自我鉴定范文模板
2014/01/21 职场文书
小学数学教研活动总结
2014/07/01 职场文书
向国旗敬礼学生寄语大全
2014/09/30 职场文书
2015年党员个人工作总结
2015/05/13 职场文书
债务纠纷代理词
2015/05/25 职场文书
投诉信格式范文
2015/07/02 职场文书
2019朋友新婚祝福语精选
2019/10/10 职场文书
Windows Server 2022 超融合部署(图文教程)
2022/06/25 Servers