Pytorch实验常用代码段汇总


Posted in Python onNovember 19, 2020

1. 大幅度提升 Pytorch 的训练速度

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

但加了这一行,似乎运行结果不一样了。

2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖

# from co-teaching train codetxtfile = save_dir + "/" + model_str + "_%s.txt"%str(args.optimizer)  ## good job!
nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
  os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) # bakeup 备份文件

3. 计算 Accuracy 返回list, 调用函数时,直接提取值,而非提取list

# from co-teaching code but MixMatch_pytorch code also has itdef accuracy(logit, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  output = F.softmax(logit, dim=1) # but actually not need it 
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True) # _, pred = logit.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # it seems this is a bug, when not all batch has same size, the mean of accuracy of each batch is not the mean of accu of all dataset
  return res

prec1, = accuracy(logit, labels, topk=(1,)) # , indicate tuple unpackage
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))

4. 善于利用 logger 文件来记录每一个 epoch 的实验值

# from Pytorch_MixMatch codeclass Logger(object):
  '''Save training process to log file with simple plot function.'''
  def __init__(self, fpath, title=None, resume=False): 
    self.file = None
    self.resume = resume
    self.title = '' if title == None else title
    if fpath is not None:
      if resume: 
        self.file = open(fpath, 'r') 
        name = self.file.readline()
        self.names = name.rstrip().split('\t')
        self.numbers = {}
        for _, name in enumerate(self.names):
          self.numbers[name] = []

        for numbers in self.file:
          numbers = numbers.rstrip().split('\t')
          for i in range(0, len(numbers)):
            self.numbers[self.names[i]].append(numbers[i])
        self.file.close()
        self.file = open(fpath, 'a') 
      else:
        self.file = open(fpath, 'w')

  def set_names(self, names):
    if self.resume: 
      pass
    # initialize numbers as empty list
    self.numbers = {}
    self.names = names
    for _, name in enumerate(self.names):
      self.file.write(name)
      self.file.write('\t')
      self.numbers[name] = []
    self.file.write('\n')
    self.file.flush()


  def append(self, numbers):
    assert len(self.names) == len(numbers), 'Numbers do not match names'
    for index, num in enumerate(numbers):
      self.file.write("{0:.4f}".format(num))
      self.file.write('\t')
      self.numbers[self.names[index]].append(num)
    self.file.write('\n')
    self.file.flush()

  def plot(self, names=None):  
    names = self.names if names == None else names
    numbers = self.numbers
    for _, name in enumerate(names):
      x = np.arange(len(numbers[name]))
      plt.plot(x, np.asarray(numbers[name]))
    plt.legend([self.title + '(' + name + ')' for name in names])
    plt.grid(True)

  def close(self):
    if self.file is not None:
      self.file.close()
# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
  logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()

5. 利用 argparser 命令行工具来进行代码重构,使用不同参数适配不同数据集,不同优化方式,不同setting, 避免多个高度冗余的重复代码

# argparser 命令行工具有一个坑的地方是,无法设置 bool 变量, flag=FALSE, 然后会解释为 字符串,仍然当做 True

发现可以使用如下命令来进行修补,来自 ICML-19-SGC github 上代码

parser.add_argument('--test', action='store_true', default=False, help='inductive training.')

当命令行出现 test 字样时,则为 args.test = true

若未出现 test 字样,则为 args.test = false

6. 使用shell 变量来设置所使用的显卡, 便于利用shell 脚本进行程序的串行,从而挂起来跑。或者多开几个 screen 进行同一张卡上多个程序并行跑,充分利用显卡的内存。

命令行中使用如下语句,或者把语句写在 shell 脚本中 # 不要忘了 export

export CUDA_VISIBLE_DEVICES=1 #设置当前可用显卡为编号为1的显卡(从 0 开始编号),即不在 0 号上跑
export CUDA_VISIBlE_DEVICES=0,1 # 设置当前可用显卡为 0,1 显卡,当 0 用满后,就会自动使用 1 显卡

一般经验,即使多个程序并行跑时,即使显存完全足够,单个程序的速度也会变慢,这可能是由于还有 cpu 和内存的限制。

这里显存占用不是阻碍,应该主要看GPU 利用率(也就是计算单元的使用,如果达到了 99% 就说明程序过多了。)

使用 watch nvidia-smi 来监测每个程序当前是否在正常跑。

7. 使用 python 时间戳来保存并进行区别不同的 result 文件

参照自己很早之前写的 co-training 的代码

8. 把训练时 命令行窗口的 print 输出全部保存到一个 log 文件:(参照 DIEN)

mkdir dnn_save_path
mkdir dnn_best_model
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &

并且使用如下命令 | tee 命令则可以同时保存到文件并且写到命令行输出:

python script/train.py train DIEN | tee train_dein2.log

9. git clone 可以用来下载 github 上的代码,更快。(由 DIEN 的下载)

git clone https://github.com/mouna99/dien.git 使用这个命令可以下载 github 上的代码库

10. (来自 DIEN ) 对于命令行参数不一定要使用 argparser 来读取,也可以直接使用 sys.argv 读取,不过这样的话,就无法指定关键字参数,只能使用位置参数。

### run.sh ###
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
#############

if __name__ == '__main__':
  if len(sys.argv) == 4:
    SEED = int(sys.argv[3]) # 0,1,2,3
  else:
    SEED = 3
  tf.set_random_seed(SEED)
  numpy.random.seed(SEED)
  random.seed(SEED)
  if sys.argv[1] == 'train':
    train(model_type=sys.argv[2], seed=SEED)
  elif sys.argv[1] == 'test':
    test(model_type=sys.argv[2], seed=SEED)
  else:
    print('do nothing...')

11.代码的一种逻辑:time_point 是一个参数变量,可以有两种方案来处理

一种直接在外面判断:

#适用于输出变量的个数不同的情况
if time_point:
A, B, C = f1(x, y, time_point=True)
else:

A, B = f1(x, y, time_point=False)
# 适用于输出变量个数和类型相同的情况
C, D = f2(x, y, time_point=time_point)

12. 写一个 shell 脚本文件来进行调节超参数, 来自 [NIPS-20 Grand]

mkdir cora
for num in $(seq 0 99) do
python train_grand.py --hidden 32 --lr 0.01 --patience 200 --seed $num --dropnode_rate 0.5 > cora/"$num".txt
done

13. 使用 或者 不使用 cuda 运行结果可能会不一样,有细微差别。

cuda 也有一个相关的随机数种子的参数,当不使用 cuda 时,这一个随机数种子没有起到作用,因此可能会得到不同的结果。

来自 NIPS-20 Grand (2020.11.18)的实验结果发现。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python用于url解码和中文解析的小脚本(python url decoder)
Aug 11 Python
跟老齐学Python之集合(set)
Sep 24 Python
使用DataFrame删除行和列的实例讲解
Apr 08 Python
python3.6的venv模块使用详解
Aug 01 Python
Python如何爬取实时变化的WebSocket数据的方法
Mar 09 Python
python 列表中[ ]中冒号‘:’的作用
Apr 30 Python
python gdal安装与简单使用
Aug 01 Python
Django使用Jinja2模板引擎的示例代码
Aug 09 Python
python 串行执行和并行执行实例
Apr 30 Python
python requests.get带header
May 05 Python
Python新手如何进行闭包时绑定变量操作
May 29 Python
python excel和yaml文件的读取封装
Jan 12 Python
Ubuntu配置Pytorch on Graph (PoG)环境过程图解
Nov 19 #Python
python基于pygame实现飞机大作战小游戏
Nov 19 #Python
Python numpy大矩阵运算内存不足如何解决
Nov 19 #Python
python3 os进行嵌套操作的实例讲解
Nov 19 #Python
如何创建一个Flask项目并进行简单配置
Nov 18 #Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 #Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 #Python
You might like
PHP的宝库目录--PEAR
2006/10/09 PHP
使用PHP制作新闻系统的思路
2006/10/09 PHP
php上传图片客户端和服务器端实现方法
2015/03/30 PHP
php readfile()修改文件上传大小设置
2017/08/11 PHP
PHP回调函数与匿名函数实例详解
2017/08/16 PHP
phpstorm 正则匹配删除空行、注释行(替换注释行为空行)
2018/01/21 PHP
ASP.NET中使用后端代码注册脚本 生成JQUERY-EASYUI的界面错位的解决方法
2010/06/12 Javascript
基于jquery的当鼠标滚轮到最底端继续加载新数据思路分享(多用于微博、空间、论坛 )
2011/10/10 Javascript
jquery移除button的inline onclick事件(已测试及兼容浏览器)
2013/01/25 Javascript
Js 时间函数getYear()的使用问题探讨
2013/04/01 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
JavaScript中匿名、命名函数的性能测试
2014/09/04 Javascript
js实现(全选)多选按钮的方法【附实例】
2016/03/30 Javascript
jQuery mobile的header和footer在点击屏幕的时候消失的解决办法
2016/07/01 Javascript
javascript数据结构中栈的应用之符号平衡问题
2017/04/11 Javascript
Angular 通过注入 $location 获取与修改当前页面URL的实例
2017/05/31 Javascript
通过jquery toggleClass()属性制作文章段落更改背景颜色
2018/05/21 jQuery
实例分析Array.from(arr)与[...arr]到底有何不同
2019/04/09 Javascript
Vue实现将数据库中带html标签的内容输出(原始HTML(Raw HTML))
2019/10/28 Javascript
vue 解决文本框被键盘遮住的问题
2019/11/06 Javascript
python基于递归解决背包问题详解
2019/07/03 Python
Python3之外部文件调用Django程序操作model等文件实现方式
2020/04/07 Python
css3 实现圆形旋转倒计时
2018/02/24 HTML / CSS
Canvas高级路径操作之拖拽对象的实现
2019/08/05 HTML / CSS
酒店前厅员工辞职信
2014/01/08 职场文书
幼儿教师演讲稿
2014/05/06 职场文书
贫困证明书格式及范文
2014/10/15 职场文书
教师工作表现自我评价
2015/03/05 职场文书
个人总结格式范文
2015/03/09 职场文书
2019感恩宣传标语!
2019/07/05 职场文书
vue首次渲染全过程
2021/04/21 Vue.js
nginx部署多前端项目的几种方法
2021/05/25 Servers
使用react-virtualized实现图片动态高度长列表的问题
2021/05/28 Javascript
关于 Python json中load和loads区别
2021/11/07 Python
剑指Offer之Java算法习题精讲二叉树专项训练
2022/03/21 Java/Android
nginx实现多geoserver服务的负载均衡
2022/05/15 Servers