pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时器(Timer)用法简单实例
Jun 04 Python
Python读写unicode文件的方法
Jul 10 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
Nov 24 Python
50行Python代码实现人脸检测功能
Jan 23 Python
python使用Matplotlib画条形图
Mar 25 Python
Django项目中添加ldap登陆认证功能的实现
Apr 04 Python
python的等深分箱实例
Nov 22 Python
matplotlib实现显示伪彩色图像及色度条
Dec 07 Python
keras .h5转移动端的.tflite文件实现方式
May 25 Python
利用PyQt5+Matplotlib 绘制静态/动态图的实现代码
Jul 13 Python
python 利用百度API识别图片文字(多线程版)
Dec 14 Python
彻底解决pip下载pytorch慢的问题方法
Mar 01 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php中数据库连接方式pdo和mysqli对比分析
2015/02/25 PHP
解决php表单重复提交实现方法
2015/09/29 PHP
PHP实现的文件操作类及文件下载功能示例
2016/12/24 PHP
PHP实现登陆并抓取微信列表中最新一组微信消息的方法
2017/07/10 PHP
yii2中LinkPager增加总页数和总记录数的实例
2017/08/28 PHP
动态修改DOM 里面的 id 属性的弊端分析
2008/09/03 Javascript
JavaScript面向对象之体会[总结]
2008/11/13 Javascript
在网页里看flash的trace数据的js类
2009/01/10 Javascript
Javascript var变量隐式声明方法
2009/10/19 Javascript
20款超赞的jQuery插件 Web开发人员必备
2011/02/26 Javascript
JS对外部文件的加载及对IFRMAME的加载的实现,当加载完成后,指定指向方法(方法回调)
2011/07/04 Javascript
JS文本获得焦点清除文本文字的示例代码
2014/01/13 Javascript
Jquery Ajax方法传值到action的方法
2014/05/11 Javascript
js如何实现淡入淡出效果
2020/11/18 Javascript
每天一篇javascript学习小结(String对象)
2015/11/18 Javascript
Jquery $when done then的用法详解
2016/05/20 Javascript
jQuery实现对象转为url参数的方法
2017/01/11 Javascript
Bootstrap组合上、下拉框简单实现代码
2017/03/06 Javascript
微信小程序 按钮滑动的实现方法
2017/09/27 Javascript
React中this丢失的四种解决方法
2019/03/12 Javascript
微信小程序修改数组长度的问题的解决
2019/12/17 Javascript
浅谈vue权限管理实现及流程
2020/04/23 Javascript
Element中Slider滑块的具体使用
2020/07/29 Javascript
jQuery实现异步上传一个或多个文件
2020/08/17 jQuery
[04:13]2018国际邀请赛典藏宝瓶Ⅱ饰品一览
2018/07/21 DOTA
python 字符串追加实例
2019/07/20 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
加拿大消费电子和手机购物网站:The Source
2017/01/28 全球购物
俄罗斯优惠券网站:BIGLION
2017/05/21 全球购物
教师网络培训感言
2014/03/09 职场文书
尊老爱幼演讲稿
2014/09/04 职场文书
加班费申请报告
2015/05/15 职场文书
教育教学读书笔记
2015/07/02 职场文书
先进个人主要事迹范文
2015/11/04 职场文书
Python实战之大鱼吃小鱼游戏的实现
2022/04/01 Python
vue-cli3.x配置全局的scss的时候报错问题及解决
2022/04/30 Vue.js