pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
400多行Python代码实现了一个FTP服务器
May 10 Python
python中enumerate的用法实例解析
Aug 18 Python
python编写简单爬虫资料汇总
Mar 22 Python
Python查看微信撤回消息代码
Jun 07 Python
Python使用pymongo模块操作MongoDB的方法示例
Jul 20 Python
在python中安装basemap的教程
Sep 20 Python
不到40行代码用Python实现一个简单的推荐系统
May 10 Python
Python使用type关键字创建类步骤详解
Jul 23 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 Python
python中wheel的用法整理
Jun 15 Python
使用tensorflow根据输入更改tensor shape
Jun 23 Python
Python matplotlib安装以及实现简单曲线的绘制
Apr 26 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
Windows下PHP5和Apache的安装与配置
2006/09/05 PHP
关于PHP递归算法和应用方法介绍
2013/04/15 PHP
sql注入与转义的php函数代码
2013/06/17 PHP
CentOS 6.2使用yum安装LAMP以及phpMyadmin详解
2013/06/17 PHP
php延迟静态绑定实例分析
2015/02/08 PHP
CodeIgniter框架常见用法工作总结
2017/03/16 PHP
Mootools 1.2教程 选项卡效果(Tabs)
2009/09/15 Javascript
javascript处理表单示例(javascript提交表单)
2014/04/28 Javascript
使用GruntJS构建Web程序之合并压缩篇
2014/06/06 Javascript
浅谈angular.js中实现双向绑定的方法$watch $digest $apply
2015/10/14 Javascript
Javascript简写条件语句(推荐)
2016/06/12 Javascript
基于JS实现移动端向左滑动出现删除按钮功能
2017/02/22 Javascript
Vue获取DOM元素样式和样式更改示例
2017/03/07 Javascript
Bootstrap + AngularJS 实现简单的数据过滤字符查找功能
2017/07/27 Javascript
jquery动态添加带有样式的HTML标签元素方法
2018/02/24 jQuery
vue权限管理系统的实现代码
2019/01/17 Javascript
如何利用vue+vue-router+elementUI实现简易通讯录
2019/05/13 Javascript
p5.js码绘“跳动的小正方形”的实现代码
2019/10/22 Javascript
python3实现磁盘空间监控
2018/06/21 Python
Python3实现腾讯云OCR识别
2018/11/27 Python
Python多线程原理与用法实例剖析
2019/01/22 Python
Python使用Beautiful Soup爬取豆瓣音乐排行榜过程解析
2019/08/15 Python
Python 基于wxpy库实现微信添加好友功能(简洁)
2019/11/29 Python
Python3创建Django项目的几种方法(3种)
2020/06/03 Python
java字符串格式化输出实例讲解
2021/01/06 Python
css3 响应式媒体查询的示例代码
2019/09/25 HTML / CSS
Spartoo美国:欧洲排名第一的在线时装零售商
2019/12/12 全球购物
团日活动总结书
2014/05/08 职场文书
孝敬父母的活动方案
2014/08/28 职场文书
趣味运动会广播稿
2014/09/13 职场文书
幼儿园教师师德师风承诺书
2015/04/28 职场文书
婚礼上证婚人致辞
2015/07/28 职场文书
信息技术国培研修日志
2015/11/13 职场文书
python简单验证码识别的实现过程
2021/06/20 Python
正则表达式基础与常用验证表达式
2022/06/16 Javascript
面试官问我Mysql的存储引擎了解多少
2022/08/05 MySQL