pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算方程式根的方法
May 07 Python
使用Python的PIL模块来进行图片对比
Feb 18 Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 Python
django的csrf实现过程详解
Jul 26 Python
Python Django 命名空间模式的实现
Aug 09 Python
解决pycharm上的jupyter notebook端口被占用问题
Dec 17 Python
tensorflow之并行读入数据详解
Feb 05 Python
Python Numpy,mask图像的生成详解
Feb 19 Python
Python编程快速上手——选择性拷贝操作案例分析
Feb 28 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
Apr 09 Python
Python基于jieba, wordcloud库生成中文词云
May 13 Python
Python代码覆盖率统计工具coverage.py用法详解
Nov 25 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php中使用$_REQUEST需要注意的一个问题
2013/05/02 PHP
php+ajax实现无刷新的新闻留言系统
2020/12/21 PHP
php微信支付接口开发程序
2016/08/02 PHP
thinkphp 抓取网站的内容并且保存到本地的实例详解
2017/08/25 PHP
PHP结合Redis+MySQL实现冷热数据交换应用案例详解
2019/07/09 PHP
php和redis实现秒杀活动的流程
2019/07/17 PHP
基于jquery+thickbox仿校内登录注册框
2010/06/07 Javascript
JQuery扩展插件Validate 5添加自定义验证方法
2011/09/05 Javascript
jquery随机展示头像代码
2011/12/21 Javascript
jQuery 1.7.2中getAll方法的疑惑分析
2012/05/23 Javascript
纯js简单日历实现代码
2013/10/05 Javascript
运用jQuery定时器的原理实现banner图片切换
2014/10/22 Javascript
Jquery全选与反选点击执行一次的解决方案
2015/08/14 Javascript
nodejs入门教程四:URL相关模块用法分析
2017/04/24 NodeJs
Javascript实现base64的加密解密方法示例
2017/06/27 Javascript
vue项目国际化vue-i18n的安装使用教程
2018/03/14 Javascript
如何在wxml中直接写js代码(wxs)
2019/11/14 Javascript
vue3.0实现插件封装
2020/12/14 Vue.js
[51:17]完美世界DOTA2联赛循环赛Inki vs DeMonsTer 第二场 10月30日
2020/10/31 DOTA
Python2中文处理纪要的实现方法
2018/03/10 Python
Python中使用Counter进行字典创建以及key数量统计的方法
2018/07/06 Python
Python常见内置高效率函数用法示例
2018/07/31 Python
对python中的装包与解包实例详解
2019/08/24 Python
Python读写文件模式和文件对象方法实例详解
2019/09/17 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
2019/10/25 Python
wxPython实现列表增删改查功能
2019/11/19 Python
python3 sleep 延时秒 毫秒实例
2020/05/04 Python
如何在 Matplotlib 中更改绘图背景的实现
2020/11/26 Python
CSS3 毛玻璃效果
2019/08/14 HTML / CSS
群众路线剖析材料
2014/02/02 职场文书
幼儿园校园小喇叭广播稿
2014/10/17 职场文书
门店店长岗位职责
2015/04/14 职场文书
道歉的话怎么说
2015/05/12 职场文书
幼儿园亲子活动感想
2015/08/07 职场文书
送给客户微信问候语!
2019/07/04 职场文书
幼师必备:幼儿园期末教师评语50条
2019/11/01 职场文书