Pytorch实验常用代码段汇总


Posted in Python onNovember 19, 2020

1. 大幅度提升 Pytorch 的训练速度

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

但加了这一行,似乎运行结果不一样了。

2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖

# from co-teaching train codetxtfile = save_dir + "/" + model_str + "_%s.txt"%str(args.optimizer)  ## good job!
nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
  os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) # bakeup 备份文件

3. 计算 Accuracy 返回list, 调用函数时,直接提取值,而非提取list

# from co-teaching code but MixMatch_pytorch code also has itdef accuracy(logit, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  output = F.softmax(logit, dim=1) # but actually not need it 
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True) # _, pred = logit.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # it seems this is a bug, when not all batch has same size, the mean of accuracy of each batch is not the mean of accu of all dataset
  return res

prec1, = accuracy(logit, labels, topk=(1,)) # , indicate tuple unpackage
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))

4. 善于利用 logger 文件来记录每一个 epoch 的实验值

# from Pytorch_MixMatch codeclass Logger(object):
  '''Save training process to log file with simple plot function.'''
  def __init__(self, fpath, title=None, resume=False): 
    self.file = None
    self.resume = resume
    self.title = '' if title == None else title
    if fpath is not None:
      if resume: 
        self.file = open(fpath, 'r') 
        name = self.file.readline()
        self.names = name.rstrip().split('\t')
        self.numbers = {}
        for _, name in enumerate(self.names):
          self.numbers[name] = []

        for numbers in self.file:
          numbers = numbers.rstrip().split('\t')
          for i in range(0, len(numbers)):
            self.numbers[self.names[i]].append(numbers[i])
        self.file.close()
        self.file = open(fpath, 'a') 
      else:
        self.file = open(fpath, 'w')

  def set_names(self, names):
    if self.resume: 
      pass
    # initialize numbers as empty list
    self.numbers = {}
    self.names = names
    for _, name in enumerate(self.names):
      self.file.write(name)
      self.file.write('\t')
      self.numbers[name] = []
    self.file.write('\n')
    self.file.flush()


  def append(self, numbers):
    assert len(self.names) == len(numbers), 'Numbers do not match names'
    for index, num in enumerate(numbers):
      self.file.write("{0:.4f}".format(num))
      self.file.write('\t')
      self.numbers[self.names[index]].append(num)
    self.file.write('\n')
    self.file.flush()

  def plot(self, names=None):  
    names = self.names if names == None else names
    numbers = self.numbers
    for _, name in enumerate(names):
      x = np.arange(len(numbers[name]))
      plt.plot(x, np.asarray(numbers[name]))
    plt.legend([self.title + '(' + name + ')' for name in names])
    plt.grid(True)

  def close(self):
    if self.file is not None:
      self.file.close()
# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
  logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()

5. 利用 argparser 命令行工具来进行代码重构,使用不同参数适配不同数据集,不同优化方式,不同setting, 避免多个高度冗余的重复代码

# argparser 命令行工具有一个坑的地方是,无法设置 bool 变量, flag=FALSE, 然后会解释为 字符串,仍然当做 True

发现可以使用如下命令来进行修补,来自 ICML-19-SGC github 上代码

parser.add_argument('--test', action='store_true', default=False, help='inductive training.')

当命令行出现 test 字样时,则为 args.test = true

若未出现 test 字样,则为 args.test = false

6. 使用shell 变量来设置所使用的显卡, 便于利用shell 脚本进行程序的串行,从而挂起来跑。或者多开几个 screen 进行同一张卡上多个程序并行跑,充分利用显卡的内存。

命令行中使用如下语句,或者把语句写在 shell 脚本中 # 不要忘了 export

export CUDA_VISIBLE_DEVICES=1 #设置当前可用显卡为编号为1的显卡(从 0 开始编号),即不在 0 号上跑
export CUDA_VISIBlE_DEVICES=0,1 # 设置当前可用显卡为 0,1 显卡,当 0 用满后,就会自动使用 1 显卡

一般经验,即使多个程序并行跑时,即使显存完全足够,单个程序的速度也会变慢,这可能是由于还有 cpu 和内存的限制。

这里显存占用不是阻碍,应该主要看GPU 利用率(也就是计算单元的使用,如果达到了 99% 就说明程序过多了。)

使用 watch nvidia-smi 来监测每个程序当前是否在正常跑。

7. 使用 python 时间戳来保存并进行区别不同的 result 文件

参照自己很早之前写的 co-training 的代码

8. 把训练时 命令行窗口的 print 输出全部保存到一个 log 文件:(参照 DIEN)

mkdir dnn_save_path
mkdir dnn_best_model
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &

并且使用如下命令 | tee 命令则可以同时保存到文件并且写到命令行输出:

python script/train.py train DIEN | tee train_dein2.log

9. git clone 可以用来下载 github 上的代码,更快。(由 DIEN 的下载)

git clone https://github.com/mouna99/dien.git 使用这个命令可以下载 github 上的代码库

10. (来自 DIEN ) 对于命令行参数不一定要使用 argparser 来读取,也可以直接使用 sys.argv 读取,不过这样的话,就无法指定关键字参数,只能使用位置参数。

### run.sh ###
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
#############

if __name__ == '__main__':
  if len(sys.argv) == 4:
    SEED = int(sys.argv[3]) # 0,1,2,3
  else:
    SEED = 3
  tf.set_random_seed(SEED)
  numpy.random.seed(SEED)
  random.seed(SEED)
  if sys.argv[1] == 'train':
    train(model_type=sys.argv[2], seed=SEED)
  elif sys.argv[1] == 'test':
    test(model_type=sys.argv[2], seed=SEED)
  else:
    print('do nothing...')

11.代码的一种逻辑:time_point 是一个参数变量,可以有两种方案来处理

一种直接在外面判断:

#适用于输出变量的个数不同的情况
if time_point:
A, B, C = f1(x, y, time_point=True)
else:

A, B = f1(x, y, time_point=False)
# 适用于输出变量个数和类型相同的情况
C, D = f2(x, y, time_point=time_point)

12. 写一个 shell 脚本文件来进行调节超参数, 来自 [NIPS-20 Grand]

mkdir cora
for num in $(seq 0 99) do
python train_grand.py --hidden 32 --lr 0.01 --patience 200 --seed $num --dropnode_rate 0.5 > cora/"$num".txt
done

13. 使用 或者 不使用 cuda 运行结果可能会不一样,有细微差别。

cuda 也有一个相关的随机数种子的参数,当不使用 cuda 时,这一个随机数种子没有起到作用,因此可能会得到不同的结果。

来自 NIPS-20 Grand (2020.11.18)的实验结果发现。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python机器学习理论与实战(二)决策树
Jan 19 Python
将字典转换为DataFrame并进行频次统计的方法
Apr 08 Python
Django框架实现的普通登录案例【使用POST方法】
May 15 Python
Python使用Beautiful Soup爬取豆瓣音乐排行榜过程解析
Aug 15 Python
python修改FTP服务器上的文件名
Sep 11 Python
PyCharm第一次安装及使用教程
Jan 08 Python
在 Pycharm 安装使用black的方法详解
Apr 02 Python
Python 判断时间是否在时间区间内的实例
May 16 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 Python
详解python安装matplotlib库三种失败情况
Jul 28 Python
python爬虫爬取图片的简单代码
Jan 18 Python
Python Pandas模块实现数据的统计分析的方法
Jun 24 Python
Ubuntu配置Pytorch on Graph (PoG)环境过程图解
Nov 19 #Python
python基于pygame实现飞机大作战小游戏
Nov 19 #Python
Python numpy大矩阵运算内存不足如何解决
Nov 19 #Python
python3 os进行嵌套操作的实例讲解
Nov 19 #Python
如何创建一个Flask项目并进行简单配置
Nov 18 #Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 #Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 #Python
You might like
PHP中exec函数和shell_exec函数的区别
2014/08/20 PHP
destoon实现VIP排名一直在前面排序的方法
2014/08/21 PHP
php+xml编程之SimpleXML的应用实例
2015/01/24 PHP
php版微信公众平台接口开发之智能回复开发教程
2016/09/22 PHP
jQuery live
2009/05/15 Javascript
button没写type=button会导致点击时提交
2014/03/06 Javascript
javascript 寻找错误方法整理
2014/06/15 Javascript
用原生JS获取CLASS对象(很简单实用)
2014/10/15 Javascript
浅谈jQuery中height与width
2015/07/06 Javascript
简单的JS时钟实例讲解
2016/01/13 Javascript
JQuery查找子元素find()和遍历集合each的方法总结
2017/03/07 Javascript
Vue cli+mui 区域滚动的实例代码
2018/01/25 Javascript
如何获取vue单文件自身源码路径
2019/05/06 Javascript
关于element-ui的隐藏组件el-scrollbar的使用
2019/05/29 Javascript
解决vue bus.$emit触发第一次$on监听不到问题
2020/07/28 Javascript
vue3.0自定义指令(drectives)知识点总结
2020/12/27 Vue.js
用实例分析Python中method的参数传递过程
2015/04/02 Python
使用Python解析JSON数据的基本方法
2015/10/15 Python
python 创建弹出式菜单的实现代码
2017/07/11 Python
python连接mongodb集群方法详解
2020/02/13 Python
Python 解析pymysql模块操作数据库的方法
2020/02/18 Python
opencv+python实现均值滤波
2020/02/19 Python
python实现图像拼接功能
2020/03/23 Python
Python求凸包及多边形面积教程
2020/04/12 Python
属性与 @property 方法让你的python更高效
2020/09/21 Python
Lululemon加拿大官网:加拿大知名体育服装零售商
2019/04/12 全球购物
培训楼经理岗位责任制
2014/02/10 职场文书
争当四好少年演讲稿
2014/09/13 职场文书
不听老师话的万能检讨书
2014/10/04 职场文书
机关作风建设自查报告及整改措施
2014/10/21 职场文书
世界遗产导游词
2015/02/13 职场文书
党员廉洁自律个人总结
2015/02/13 职场文书
2015年政务公开工作总结
2015/05/19 职场文书
付款证明模板
2015/06/19 职场文书
SpringCloud中分析讲解Feign组件添加请求头有哪些坑梳理
2022/06/21 Java/Android
如何用H5实现好玩的2048小游戏
2022/07/23 HTML / CSS