pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中用has_key()方法查找键是否存在的教程
May 21 Python
python获取一组汉字拼音首字母的方法
Jul 01 Python
python selenium UI自动化解决验证码的4种方法
Jan 05 Python
Python实现的FTP通信客户端与服务器端功能示例
Mar 28 Python
对python列表里的字典元素去重方法详解
Jan 21 Python
浅谈python3中input输入的使用
Aug 02 Python
Python IDE Pycharm中的快捷键列表用法
Aug 08 Python
python虚拟环境模块venv使用及示例
Mar 04 Python
Python3 pickle对象串行化代码实例解析
Mar 23 Python
通过实例解析Python文件操作实现步骤
Sep 21 Python
Python 流媒体播放器的实现(基于VLC)
Apr 28 Python
Django模型层实现多表关系创建和多表操作
Jul 21 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
世界咖啡生产者论坛呼吁:需要立即就咖啡价格采取认真行动
2021/03/06 咖啡文化
php封装的smartyBC类完整实例
2016/10/19 PHP
php reset() 函数指针指向数组中的第一个元素并输出实例代码
2016/11/21 PHP
Laravel 中使用 Vue.js 实现基于 Ajax 的表单提交错误验证操作
2017/06/30 PHP
ThinkPHP5.0框架实现切换数据库的方法分析
2019/10/30 PHP
基于jquery+thickbox仿校内登录注册框
2010/06/07 Javascript
JS网页播放声音实现代码兼容各种浏览器
2013/09/22 Javascript
jquery实现全选功能效果的实现代码
2016/05/05 Javascript
微信支付 JS API支付接口详解
2016/07/11 Javascript
基于jQuery实现仿微博发布框字数提示
2016/07/27 Javascript
微信小程序(六):列表上拉加载下拉刷新示例
2017/01/13 Javascript
ES6中Generator与异步操作实例分析
2017/03/31 Javascript
jQuery Validate 校验多个相同name的方法
2017/05/18 jQuery
vue实现nav导航栏的方法
2017/12/13 Javascript
使用 Vue 绑定单个或多个 Class 名的实例代码
2018/01/08 Javascript
图解javascript作用域链
2019/05/27 Javascript
js实现div色块碰撞
2020/01/16 Javascript
Javascript执行上下文顺序的深入讲解
2020/11/04 Javascript
python 七种邮件内容发送方法实例
2014/04/22 Python
Python序列化基础知识(json/pickle)
2017/10/19 Python
Django接收post前端返回的json格式数据代码实现
2019/07/31 Python
基于python+selenium的二次封装的实现
2020/01/06 Python
TensorFlow内存管理bfc算法实例
2020/02/03 Python
Python爬取新型冠状病毒“谣言”新闻进行数据分析
2020/02/16 Python
Django 5种类型Session使用方法解析
2020/04/29 Python
CSS3 选择器 伪类选择器介绍
2012/01/21 HTML / CSS
汽车电子与维修专业大学生求职信
2013/09/28 职场文书
信息专业本科生个人的自我评价
2013/10/28 职场文书
营销总经理的岗位职责
2013/12/15 职场文书
红色故事演讲稿
2014/05/22 职场文书
保护环境标语
2014/06/09 职场文书
2014副局长群众路线对照检查材料思想汇报
2014/09/22 职场文书
12.4全国法制宣传日活动方案
2014/11/02 职场文书
解除劳动合同证明书模板
2014/11/20 职场文书
办公室日常管理制度
2015/08/04 职场文书
CSS 鼠标点击拖拽效果的实现代码
2022/12/24 HTML / CSS