horovod tensorflow 分布式多gpu_horovod sparse_as_dense-程序员宅基地

技术标签: DL tools  

概念
rank is your index within the entire ring, local_rank is your index within your node. For example, you have 4 nodes and 4 GPUs each node, so you spawn 16 workers. Every worker will have a rank [0, 15], and every worker will have a local_rank [0, 3]. You use local_rank for GPU pinning because there’s typically one GPU available on the node per process. It wouldn’t make sense to use rank here because rank could be 10, but you only have 4 GPUs so there is no GPU 10.

# 在其他import前引入
try:
    import horovod.tensorflow as hvd
    hvd.init()
except Exception as e:
    hvd = None
    print('no horovod')

# 打印信息
if hvd:
    tf.logging.info('Total workers: {}, local workers: {}'.format(
        hvd.size(), hvd.local_size()))
    tf.logging.info('Global rank: {}, local rank: {}'.format(
        hvd.rank(), hvd.local_rank()))

# 数据集读取配置:对数据集进行分片, 不同进程读取不同子集。
d = tf.data.TFRecordDataset(input_file)
if is_training:
    if hvd is not None:
        d = d.shard(hvd.size(), hvd.rank())
    d = d.shuffle(buffer_size=100)
    d = d.repeat()

# 加载权重配置:只对第一个rank载入权重
if init_checkpoint and is_training and (hvd is None or hvd.rank()==0):
    for init_file in init_checkpoint.split(","):
        assignment_map, tmp_init_map = get_assignment_map_from_checkpoint(tvars, init_file, extra_load_var)
        tf.train.init_from_checkpoint(init_file, assignment_map)
        initialized_variable_names.update(tmp_init_map)

# 学习率调整:
if hvd:
    learning_rate = learning_rate * hvd.size()

# 分布式优化器配置:使用 ring-allreduce 平均梯度
if hvd is not None:
    # we enable compression only for fp16
    from horovod.tensorflow.compression import Compression
    if use_fp16:
        compression = Compression.fp16
    else:
        compression = Compression.none

    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True,
                                         compression=compression)

# 配置每个进程模型迭代次数
if FLAGS.do_train:
    # train_examples = processor.get_train_examples(FLAGS.data_dir, FLAGS.img_dir)
    num_train_steps = int(
        train_num / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        # len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    if hvd:
        num_train_steps = num_train_steps // hvd.size()

model_fn = model_fn_builder(
    bert_config=bert_config,
    num_labels=len(label_list),
    init_checkpoint=FLAGS.init_checkpoint,
    learning_rate=FLAGS.learning_rate,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps)

# GPU config GPU配置:使用local rank分配当前机器上当前进程可视gpu
run_config = tf.ConfigProto()
# train_params.get('gpu_allow_growth', False)
run_config.gpu_options.allow_growth = True
run_config.allow_soft_placement = True

if hvd:
    run_config.gpu_options.visible_device_list = str(hvd.local_rank())

if FLAGS.use_xla:
    run_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

# checkpoint配置:只对第一个保存模型
save_checkpoints_steps = FLAGS.save_checkpoints_steps if hvd is None or hvd.rank() == 0 else None
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=FLAGS.output_dir,
    config=tf.estimator.RunConfig(
        save_checkpoints_steps=save_checkpoints_steps,
        save_checkpoints_secs=None,
        keep_checkpoint_every_n_hours=2,
        log_step_count_steps=400,
        session_config=run_config))


# 模型训练hook配置:将变量从第一个流程向其他流程传播,以实现一致性初始化。
if FLAGS.do_train and hvd is not None:
    training_hook = [hvd.BroadcastGlobalVariablesHook(0)]
else:
    training_hook = []
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps,
                hooks=training_hook)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_16234613/article/details/96186398

智能推荐

IOS学习—强引用(__strong)和 弱引用(__weak)_ios __strong-程序员宅基地

文章浏览阅读2.6k次。转载于开源中国在Objective-C的ARC模式中,id obj1 = [[NSObject alloc] init]; 这里虽然没有显示的声明为__strong,但是Objective-C默认声明的一个对象就为__strong,即: id obj1 = [[NSObject alloc] init]; 和 id __strong obj1 = [[NSObject alloc] init_ios __strong

ubuntu18.04下qt5.9编译错误: -1: error: cannot find -lGL_:-1: error: /usr/local/qt-5.9/lib/libqt5core.so: u-程序员宅基地

文章浏览阅读2.1k次。只要安装libGL即可:sudo apt-get install libqt4-devsudo apt update再重新编译就ok啦 _:-1: error: /usr/local/qt-5.9/lib/libqt5core.so: undefined reference to `uca

如何绘制深度神经网络图_深度学习神经网络图怎么画-程序员宅基地

文章浏览阅读2.7k次,点赞4次,收藏22次。1.在线版本的NN-SVG_深度学习神经网络图怎么画

菜鸟学习Android笔记-20140311_textview3:[i18n] hardcoded string-程序员宅基地

文章浏览阅读771次。1、编写布局文件时,遇到这样的警告,“[I18N] Hardcoded string "昵称:", should use @string resource” 原来的代码:

不同参数统计运行时间 large_integer c语言,使用LARGE_INTEGER查看系统运行时间-程序员宅基地

文章浏览阅读252次。众所周知,windows ce是一个实时操作,因此提供了不少的优先级给用户.优先级最高为0级,也就是说使用0优先级的程序, 可以挂起整个系统, 来运行你的程序对于实时性比较的领域, 我们作为程序员的 应该清楚的知道你的程序模块运行的时间 是非常必要的. 当然这个模块运行的时间也不是完全的稳定的, 几次运行的时间相差几十毫秒是很正常的. 因此我们只要知道大概的时间就可以了.当然, 大家..._large_integer计算时间

ssh登陆服务器locale告警(-bash: warning: setlocale:)的处理方法-程序员宅基地

文章浏览阅读1.9k次。 使用ssh远程登陆 IDC机房服务器,发现老是出现如下告警信息:-bash: warning: setlocale: LC_CTYPE: cannot change locale (en_US.UTF-8): No such file or directory-bash: warning: setlocale: LC_COLLATE: cannot change locale (en_U..._bash: warning: setlocale: lc_ctype: cannot change locale (en_us.utf-8): no s

随便推点

SQL中IF ELSE及MySQL伪列rownum的使用_mysql 如何使用if else 生成伪列-程序员宅基地

文章浏览阅读290次。编写SQL语句时难免会遇到各种条件判断,例如统计:count(case when then end)今天,我们要说的是if判断,eg:SELECT IF(c19='1','已评价','未评价')c19 FROM A05;关于伪列,广为人知的是oracle有伪列rownum,因为一些需求需要用mysql实现类似Oracle的伪列,方法方式如下:SELECT rowid, i01..._mysql 如何使用if else 生成伪列

【C++】虚函数及其内存布局_c++虚函数内存分布-程序员宅基地

文章浏览阅读1.7k次,点赞5次,收藏20次。一、函数调用捆绑把函数体与函数调用相联系称为捆绑。当捆绑在程序运行之前(由编译器和连接器)完成时,称为早捆绑。C编译只有一种函数调用方式,就是早捆绑。早捆绑引起的问题:因为编译器在只有对象的地址时它并不知道要调用的正确函数。根据对象的类型,捆绑发生在运行时,这种捆绑方式称为晚捆绑,又称动态捆绑。二、虚函数对于特定的函数,为了引起晚捆绑,C++要求在基类中声明这个函数时使用virtual关键字,这样的函数称为虚函数。晚捆绑只对virtual函数起作用,而且只在使用含有virtual函._c++虚函数内存分布

matlab 相位校正,科学网—全相位比值校正法 - 王兆华的博文-程序员宅基地

文章浏览阅读709次。加hann窗全相位比值校正法和加hann窗fft比值校正法校正方法类同,只须将二个振幅比改为振幅开方比即可。这里加hann窗是关键,过去测试时,直接调用Matlab中的hann(N)窗,频率和振幅校正效果差,见表5加hann窗全相位比值校正法测试结果。表4为加n-hann窗全相位比值校正法,其频率校正精度,相位校正精度和振幅校正精度都很高,甚至可以和表1加n-hann卷积窗apfft/apfft校..._比值校正法频谱校正matlab

创建登录界面_建网站登录页面-程序员宅基地

文章浏览阅读334次。package zhoushi;import javax.swing.*;//调用库import java.awt.*;import java.awt.event.*;public class jh extends JFrame implements ActionListener{//创建类jh继承JFrame,实现接口ActionListener JPanel log;//定义变量_建网站登录页面

win10安装linux虚拟机并配置shell工具连接_shell确认虚拟机光盘连接-程序员宅基地

文章浏览阅读1k次。1:虚拟机安装先看怎么用VMware安装一个虚拟机,全部放图,一步步来。主要还是以防以后我自己忘记怎么搞了,老了,记性不好了。VMware就在网上随便下载一个了,镜像我会在下面放上我的或者大家也可以自己去网上下一个。第一步:新建虚拟机第二步:选择类型第三步:选择映像文件,一般都会检测到,检测不到就只能自己打开浏览去找吧!第四步:给虚拟机命名,可以更改虚拟机安装位置。反正我是不会装在系统盘的,这辈子都不会的o(´^`)o第五步:默认选择是虚拟磁盘拆分成多个文件,但._shell确认虚拟机光盘连接

计算机视觉模型常用评价指标_平均交并比-程序员宅基地

文章浏览阅读3.6k次,点赞9次,收藏36次。分类任务常用准确率、精确率、召回率、F1_scores、ROC曲线等指标来评价模型的优劣,当然这些基础指标也可以用来评价分割模型或检测模型,它们基本上是可以通用的。混淆矩阵是对分类问题预测结果的总结,也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法。混淆矩阵中含有4个分类问题的基础指标,如下表所示。........._平均交并比