pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中文乱码的解决方法
Nov 04 Python
python中bisect模块用法实例
Sep 25 Python
Python的Django框架可适配的各种数据库介绍
Jul 15 Python
Python上传package到Pypi(代码简单)
Feb 06 Python
Python编程实现正则删除命令功能
Aug 30 Python
python opencv检测目标颜色的实例讲解
Apr 02 Python
利用Python库Scapy解析pcap文件的方法
Jul 23 Python
Python定时任务工具之APScheduler使用方式
Jul 24 Python
python 函数嵌套及多函数共同运行知识点讲解
Mar 03 Python
python中pandas库中DataFrame对行和列的操作使用方法示例
Jun 14 Python
Python实现打包成库供别的模块调用
Jul 13 Python
Python Django模型详解
Oct 05 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
PHP5.5和之前的版本empty函数的不同之处
2014/06/13 PHP
php获取从百度、谷歌等搜索引擎进入网站关键词的方法
2015/07/08 PHP
Yii框架组件的事件机制原理与用法分析
2020/04/07 PHP
一步一步制作jquery插件Tabs实现过程
2010/07/06 Javascript
Extjs中通过Tree加载右侧TabPanel具体实现
2013/05/05 Javascript
iframe子页面获取父页面元素的方法
2013/11/05 Javascript
JS脚本defer的作用示例介绍
2014/01/02 Javascript
javascript实现随机读取数组的方法
2015/08/03 Javascript
window.open打开窗口被拦截的快速解决方法
2016/08/04 Javascript
js利用appendChild对标签进行排序的实现方法
2016/10/16 Javascript
解决同一页面中两个iframe互相调用jquery,js函数的方法
2016/12/12 Javascript
微信小程序 开发之顶部导航栏实例代码
2017/02/23 Javascript
vue.js的提示组件
2017/03/02 Javascript
JS实现异步上传压缩图片
2017/04/22 Javascript
Vue 2.5.2下axios + express 本地请求404的解决方法
2018/02/21 Javascript
vue-router判断页面未登录自动跳转到登录页的方法示例
2018/11/04 Javascript
微信小程序canvas.drawImage完全显示图片问题的解决
2018/11/30 Javascript
基于vue-cli3+typescript的tsx开发模板搭建过程分享
2020/02/28 Javascript
Vue获取微博授权URL代码实例
2020/11/04 Javascript
跟老齐学Python之集合的关系
2014/09/24 Python
在Django中创建URLconf相关的通用视图的方法
2015/07/20 Python
transform python环境快速配置方法
2018/09/27 Python
浅析Python+OpenCV使用摄像头追踪人脸面部血液变化实现脉搏评估
2019/10/17 Python
pytorch sampler对数据进行采样的实现
2019/12/31 Python
python 安装impala包步骤
2020/03/28 Python
Linux Interview Questions For software testers
2012/06/02 面试题
思想汇报范文
2013/11/04 职场文书
毕业生自荐信
2013/12/14 职场文书
电大会计学自我鉴定
2014/02/06 职场文书
后勤服务中心总经理工作职责
2014/03/03 职场文书
生育关怀行动实施方案
2014/03/26 职场文书
红色影片观后感
2015/06/18 职场文书
导游词之无锡华莱坞
2019/12/02 职场文书
GO语言字符串处理函数之处理Strings包
2022/04/14 Golang
MySQL索引 高效获取数据的数据结构
2022/05/02 MySQL
关于pytest结合csv模块实现csv格式的数据驱动问题
2022/05/30 Python