Pytorch实验常用代码段汇总


Posted in Python onNovember 19, 2020

1. 大幅度提升 Pytorch 的训练速度

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

但加了这一行,似乎运行结果不一样了。

2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖

# from co-teaching train codetxtfile = save_dir + "/" + model_str + "_%s.txt"%str(args.optimizer)  ## good job!
nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
  os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) # bakeup 备份文件

3. 计算 Accuracy 返回list, 调用函数时,直接提取值,而非提取list

# from co-teaching code but MixMatch_pytorch code also has itdef accuracy(logit, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  output = F.softmax(logit, dim=1) # but actually not need it 
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True) # _, pred = logit.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # it seems this is a bug, when not all batch has same size, the mean of accuracy of each batch is not the mean of accu of all dataset
  return res

prec1, = accuracy(logit, labels, topk=(1,)) # , indicate tuple unpackage
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))

4. 善于利用 logger 文件来记录每一个 epoch 的实验值

# from Pytorch_MixMatch codeclass Logger(object):
  '''Save training process to log file with simple plot function.'''
  def __init__(self, fpath, title=None, resume=False): 
    self.file = None
    self.resume = resume
    self.title = '' if title == None else title
    if fpath is not None:
      if resume: 
        self.file = open(fpath, 'r') 
        name = self.file.readline()
        self.names = name.rstrip().split('\t')
        self.numbers = {}
        for _, name in enumerate(self.names):
          self.numbers[name] = []

        for numbers in self.file:
          numbers = numbers.rstrip().split('\t')
          for i in range(0, len(numbers)):
            self.numbers[self.names[i]].append(numbers[i])
        self.file.close()
        self.file = open(fpath, 'a') 
      else:
        self.file = open(fpath, 'w')

  def set_names(self, names):
    if self.resume: 
      pass
    # initialize numbers as empty list
    self.numbers = {}
    self.names = names
    for _, name in enumerate(self.names):
      self.file.write(name)
      self.file.write('\t')
      self.numbers[name] = []
    self.file.write('\n')
    self.file.flush()


  def append(self, numbers):
    assert len(self.names) == len(numbers), 'Numbers do not match names'
    for index, num in enumerate(numbers):
      self.file.write("{0:.4f}".format(num))
      self.file.write('\t')
      self.numbers[self.names[index]].append(num)
    self.file.write('\n')
    self.file.flush()

  def plot(self, names=None):  
    names = self.names if names == None else names
    numbers = self.numbers
    for _, name in enumerate(names):
      x = np.arange(len(numbers[name]))
      plt.plot(x, np.asarray(numbers[name]))
    plt.legend([self.title + '(' + name + ')' for name in names])
    plt.grid(True)

  def close(self):
    if self.file is not None:
      self.file.close()
# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
  logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()

5. 利用 argparser 命令行工具来进行代码重构,使用不同参数适配不同数据集,不同优化方式,不同setting, 避免多个高度冗余的重复代码

# argparser 命令行工具有一个坑的地方是,无法设置 bool 变量, flag=FALSE, 然后会解释为 字符串,仍然当做 True

发现可以使用如下命令来进行修补,来自 ICML-19-SGC github 上代码

parser.add_argument('--test', action='store_true', default=False, help='inductive training.')

当命令行出现 test 字样时,则为 args.test = true

若未出现 test 字样,则为 args.test = false

6. 使用shell 变量来设置所使用的显卡, 便于利用shell 脚本进行程序的串行,从而挂起来跑。或者多开几个 screen 进行同一张卡上多个程序并行跑,充分利用显卡的内存。

命令行中使用如下语句,或者把语句写在 shell 脚本中 # 不要忘了 export

export CUDA_VISIBLE_DEVICES=1 #设置当前可用显卡为编号为1的显卡(从 0 开始编号),即不在 0 号上跑
export CUDA_VISIBlE_DEVICES=0,1 # 设置当前可用显卡为 0,1 显卡,当 0 用满后,就会自动使用 1 显卡

一般经验,即使多个程序并行跑时,即使显存完全足够,单个程序的速度也会变慢,这可能是由于还有 cpu 和内存的限制。

这里显存占用不是阻碍,应该主要看GPU 利用率(也就是计算单元的使用,如果达到了 99% 就说明程序过多了。)

使用 watch nvidia-smi 来监测每个程序当前是否在正常跑。

7. 使用 python 时间戳来保存并进行区别不同的 result 文件

参照自己很早之前写的 co-training 的代码

8. 把训练时 命令行窗口的 print 输出全部保存到一个 log 文件:(参照 DIEN)

mkdir dnn_save_path
mkdir dnn_best_model
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &

并且使用如下命令 | tee 命令则可以同时保存到文件并且写到命令行输出:

python script/train.py train DIEN | tee train_dein2.log

9. git clone 可以用来下载 github 上的代码,更快。(由 DIEN 的下载)

git clone https://github.com/mouna99/dien.git 使用这个命令可以下载 github 上的代码库

10. (来自 DIEN ) 对于命令行参数不一定要使用 argparser 来读取,也可以直接使用 sys.argv 读取,不过这样的话,就无法指定关键字参数,只能使用位置参数。

### run.sh ###
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
#############

if __name__ == '__main__':
  if len(sys.argv) == 4:
    SEED = int(sys.argv[3]) # 0,1,2,3
  else:
    SEED = 3
  tf.set_random_seed(SEED)
  numpy.random.seed(SEED)
  random.seed(SEED)
  if sys.argv[1] == 'train':
    train(model_type=sys.argv[2], seed=SEED)
  elif sys.argv[1] == 'test':
    test(model_type=sys.argv[2], seed=SEED)
  else:
    print('do nothing...')

11.代码的一种逻辑:time_point 是一个参数变量,可以有两种方案来处理

一种直接在外面判断:

#适用于输出变量的个数不同的情况
if time_point:
A, B, C = f1(x, y, time_point=True)
else:

A, B = f1(x, y, time_point=False)
# 适用于输出变量个数和类型相同的情况
C, D = f2(x, y, time_point=time_point)

12. 写一个 shell 脚本文件来进行调节超参数, 来自 [NIPS-20 Grand]

mkdir cora
for num in $(seq 0 99) do
python train_grand.py --hidden 32 --lr 0.01 --patience 200 --seed $num --dropnode_rate 0.5 > cora/"$num".txt
done

13. 使用 或者 不使用 cuda 运行结果可能会不一样,有细微差别。

cuda 也有一个相关的随机数种子的参数,当不使用 cuda 时,这一个随机数种子没有起到作用,因此可能会得到不同的结果。

来自 NIPS-20 Grand (2020.11.18)的实验结果发现。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现全局变量的两个解决方法
Jul 03 Python
简单讲解Python中的字符串与字符串的输入输出
Mar 13 Python
Python进行数据提取的方法总结
Aug 22 Python
Python多线程中阻塞(join)与锁(Lock)使用误区解析
Apr 27 Python
对Python 数组的切片操作详解
Jul 02 Python
机器学习实战之knn算法pandas
Jun 22 Python
Python多叉树的构造及取出节点数据(treelib)的方法
Aug 09 Python
Python udp网络程序实现发送、接收数据功能示例
Dec 09 Python
基于pytorch padding=SAME的解决方式
Feb 18 Python
opencv 查找连通区域 最大面积实例
Jun 04 Python
Python字符串函数strip()原理及用法详解
Jul 23 Python
Python实现简单的猜单词小游戏
Oct 28 Python
Ubuntu配置Pytorch on Graph (PoG)环境过程图解
Nov 19 #Python
python基于pygame实现飞机大作战小游戏
Nov 19 #Python
Python numpy大矩阵运算内存不足如何解决
Nov 19 #Python
python3 os进行嵌套操作的实例讲解
Nov 19 #Python
如何创建一个Flask项目并进行简单配置
Nov 18 #Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 #Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 #Python
You might like
php防盗链的常用方法小结
2010/07/02 PHP
ThinkPHP模板输出display用法分析
2014/11/26 PHP
CentOS6.5 编译安装lnmp环境
2014/12/21 PHP
PHP生成指定随机字符串的简单实现方法
2015/04/01 PHP
在IE下:float属性会影响offsetTop的取值
2006/12/22 Javascript
extjs 学习笔记 四 带分页的grid
2009/10/20 Javascript
JSChart轻量级图形报表工具(内置函数中文参考)
2010/10/11 Javascript
jQuery aminate方法定位到页面具体位置
2013/12/26 Javascript
Node.js重新刷新session过期时间的方法
2016/02/04 Javascript
JS判断浏览器是否安装flash插件的简单方法
2016/09/13 Javascript
nodejs实例解析(输出hello world)
2017/01/03 NodeJs
JavaScript中Require调用js的实例分享
2017/10/27 Javascript
使用ajax的post同步执行(实现方法)
2017/12/21 Javascript
微信小程序中使用ECharts 异步加载数据实现图表功能
2018/07/13 Javascript
Vue2.X 通过AJAX动态更新数据
2018/07/17 Javascript
JS加密插件CryptoJS实现AES加密操作示例
2018/08/16 Javascript
关于element-ui的隐藏组件el-scrollbar的使用
2019/05/29 Javascript
JS对象属性的检测与获取操作实例分析
2020/03/17 Javascript
JavaScript中的全局属性与方法深入解析
2020/06/14 Javascript
Python中单、双下划线的区别总结
2017/12/01 Python
python查看模块安装位置的方法
2018/10/16 Python
python实现n个数中选出m个数的方法
2018/11/13 Python
python opencv将图片转为灰度图的方法示例
2019/07/31 Python
Python线程障碍对象Barrier原理详解
2019/12/02 Python
keras获得model中某一层的某一个Tensor的输出维度教程
2020/01/24 Python
Python 面向对象部分知识点小结
2020/03/09 Python
10个python爬虫入门基础代码实例 + 1个简单的python爬虫完整实例
2020/12/16 Python
Coccinelle官网:意大利的著名皮具品牌
2019/05/15 全球购物
乌克兰电子和家用电器商店:Foxtrot
2019/07/23 全球购物
办公室年终个人自我评价
2013/10/28 职场文书
房地产项目建议书
2014/03/12 职场文书
一体化教学实施方案
2014/05/10 职场文书
卖车协议书范例
2014/09/16 职场文书
代收款委托书范本
2014/10/01 职场文书
简单了解 MySQL 中相关的锁
2021/05/25 MySQL
Java数组详细介绍及相关工具类
2022/04/14 Java/Android