Wandb使用指南

发布于:2024-08-18 ⋅ 阅读:(338) ⋅ 点赞:(0)

安装:

pip install wandb

登录

wanbd login

在terminal中操作查看你的API key并粘贴回车进行授权(https://wandb.ai/authorize

设置离线模式/在线模式

设置为offline会在无网络(内网)的时候使用,常用于debug的时候使用,因为这样启动速度快

注意:设置offline要在wandb init之前使用,否则不起作用

# 设置为离线模式,常用于测试、debug的时候,因为在线模式启动速度慢
import os
os.environ["WANDB_MODE"]="offline"

代码中初始化wandb

# 初始化wandb
wandb.init(project="config_example")

如果记录参数和日志?

1)记录运行args参数

有很多方式,但是最常用的方式是直接通过一行代码上传args:

# 储存运行参数:将参数值转为dict,然后再储存
wandb.config.update(vars(args))

上传后的参数储存在overview中: 

2)记录运行log + images

wandb.log()记录这些值

然后image要通过一步wandb.Image(image)转换才可以存储

        # 储存运行过程中的图像:随机生成一个图像作为示例
        data = np.random.rand(256, 256, 3) * 255
        data = data.astype(np.uint8)
        image = Image.fromarray(data, 'RGB')

        # 储存运行过程中的loss等日志
        wandb.log(
            {
                "epoch": epoch,
                "train_acc": train_acc,
                "train_loss": train_loss,
                "val_acc": val_acc,
                "val_loss": val_loss,
                'images': wandb.Image(image),
            }
        )

  3)记录某些文本

记录总结性的文本

例如:参数量

wandb.run.summary['Trainable parameters'] = f"{n_params / 1e6}M"

记录带有格式的文本 

某些时候可能需要记录model的构造等等,我们需要使用:

    wandb.log(
        {"Model_architecture": wandb.Table(columns=["Model_architecture"], data=[[str(model_without_ddp)]])}
    )

结果查看: 

运行代码后,会出现日志,直接点击本次运行结果的连接即可

完整示例代码执行:

import wandb
import argparse
import numpy as np
import random
from PIL import Image

# 初始化wandb
wandb.init(project="config_example")


def train_one_epoch(epoch, lr, bs):
    acc = 0.25 + ((epoch / 30) + (random.random() / 10))
    loss = 0.2 + (1 - ((epoch - 1) / 10 + random.random() / 5))
    return acc, loss


def evaluate_one_epoch(epoch):
    acc = 0.1 + ((epoch / 20) + (random.random() / 10))
    loss = 0.25 + (1 - ((epoch - 1) / 10 + random.random() / 6))
    return acc, loss


def main(args):
    # 储存运行参数:将参数值转为dict,然后再储存
    wandb.config.update(vars(args))

    for epoch in np.arange(1, args.epochs):
        train_acc, train_loss = train_one_epoch(epoch, args.learning_rate, args.batch_size)
        val_acc, val_loss = evaluate_one_epoch(epoch)

        # 储存运行过程中的图像:随机生成一个图像作为示例
        data = np.random.rand(256, 256, 3) * 255
        data = data.astype(np.uint8)
        image = Image.fromarray(data, 'RGB')

        # 储存运行过程中的loss等日志
        wandb.log(
            {
                "epoch": epoch,
                "train_acc": train_acc,
                "train_loss": train_loss,
                "val_acc": val_acc,
                "val_loss": val_loss,
                'images': wandb.Image(image),
            }
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=int, default=0.001, help="Learning rate")

    args = parser.parse_args()
    main(args)


网站公告

今日签到

点亮在社区的每一天
去签到