pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈编码,解码,乱码的问题
Dec 30 Python
Django 跨域请求处理的示例代码
May 02 Python
对Python中list的倒序索引和切片实例讲解
Nov 15 Python
Python音频操作工具PyAudio上手教程详解
Jun 26 Python
pybind11和numpy进行交互的方法
Jul 04 Python
Python学习笔记之While循环用法分析
Aug 14 Python
基于Django signals 信号作用及用法详解
Mar 28 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
May 18 Python
python爬虫筛选工作实例讲解
Nov 23 Python
Python 使用dict实现switch的操作
Apr 07 Python
解决Python字典查找报Keyerror的问题
May 26 Python
详解PyTorch模型保存与加载
Apr 28 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
PHP实现微信网页授权开发教程
2016/01/19 PHP
js电信网通双线自动选择技巧
2008/11/18 Javascript
javascript-表格排序(降序/反序)实现介绍(附图)
2013/05/30 Javascript
php跨域调用json的例子
2013/11/13 Javascript
Jquery焦点与失去焦点示例应用
2014/06/10 Javascript
javascript实现瀑布流自适应遇到的问题及解决方案
2015/01/28 Javascript
JS实现定时自动关闭DIV层提示框的方法
2015/05/11 Javascript
jQuery 弹出层插件(推荐)
2016/05/24 Javascript
详解Javascript ES6中的箭头函数(Arrow Functions)
2016/08/24 Javascript
ajax接收后台数据在html页面显示
2017/02/19 Javascript
JS检测数组类型的方法小结
2017/03/14 Javascript
浅谈vue实现数据监听的函数 Object.defineProperty
2017/06/08 Javascript
vue实现商城购物车功能
2017/11/27 Javascript
在Vue 中实现循环渲染多个相同echarts图表
2020/07/20 Javascript
创建与框架无关的JavaScript插件
2020/12/01 Javascript
python+PyQT实现系统桌面时钟
2020/06/16 Python
python+opencv实现阈值分割
2018/12/26 Python
Python常用模块logging——日志输出功能(示例代码)
2019/11/20 Python
浅析python,PyCharm,Anaconda三者之间的关系
2019/11/27 Python
python两种获取剪贴板内容的方法
2020/11/06 Python
详解appium自动化测试工具(monitor、uiautomatorviewer)
2021/01/27 Python
html2canvas截图空白问题的解决
2020/03/24 HTML / CSS
详解HTML5常用的语义化标签
2019/09/27 HTML / CSS
End Clothing美国站:英国男士潮牌商城
2018/04/20 全球购物
迪卡侬比利时官网:Decathlon比利时
2019/12/28 全球购物
什么是SQL Server的确定性函数和不确定性函数
2016/08/04 面试题
掌上明珠Java程序员面试总结
2016/02/23 面试题
业务部门经理岗位职责
2014/02/23 职场文书
槐乡的孩子教学反思
2014/04/27 职场文书
车辆工程专业求职信
2014/04/28 职场文书
工作证明范本(2篇)
2014/09/14 职场文书
学习计划书怎么写
2014/09/15 职场文书
体检通知范文
2015/04/21 职场文书
Qt自定义Plot实现曲线绘制的详细过程
2021/11/02 Python
GO语言字符串处理函数之处理Strings包
2022/04/14 Golang
mysql如何查询连续记录
2022/05/11 MySQL