habitat中的坑(一):训练模型的时候找不到数据

发布于:2024-03-11 ⋅ 阅读:(103) ⋅ 点赞:(0)

在habitat中训练一个模型需要指定配置文件,(根据目前的学习)一般要指定两个yaml文件:

  • 一个是训练的配置文件
  • 一个是任务的配置文件

举例如下:

import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch

if __name__ == "__main__":
    run_type = "train"      #指定是训练还是评估
    #指定训练配置文件
    config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")

  #下面是在代码中对一些配置参数进行修改
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"
    config.freeze()
    
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)###config.TRAINER_NAME指定模型名字
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()

上面所指定的训练文件ppo_pointnav_example.yaml中有一个配置项如下:

BASE_TASK_CONFIG_PATH: "../configs/tasks/pointnav.yaml"

从上面的代码可以看出来在代码中指定训练的配置文件,在训练配置文件中配置任务配置文件。

训练过程肯定要指定数据集(TASK_CONFIG.DATASET.DATA_PATH)(在训练配置文件中配置还是在任务配置文件中配置?目前至少看到在任务配置文件中是可以的)。

如果TASK_CONFIG.DATASET.DATA_PATH没有重新指定,会有默认值(目前知道有些默认值是从…/habitat-lab/habitat_baselines/config/default.py 中定义的)。

如果是点导航任务,需要同时指定正确的DATA_PATH和SCENES_DIR,否则会报错Could not find dataset file

具体原因见下面的代码
文件位置:.../habitat/datasets/pointnav/pointnav_dataset.py

@registry.register_dataset(name="PointNav-v1")
class PointNavDatasetV1(Dataset):
    r"""Class inherited from Dataset that loads Point Navigation dataset."""

    episodes: List[NavigationEpisode]
    content_scenes_path: str = "{data_path}/content/{scene}.json.gz"

    @staticmethod
    def check_config_paths_exist(config: Config) -> bool:
        return os.path.exists(
            config.DATA_PATH.format(split=config.SPLIT)
        ) and os.path.exists(config.SCENES_DIR)

    @classmethod
    def get_scenes_to_load(cls, config: Config) -> List[str]:
        r"""Return list of scene ids for which dataset has separate files with
        episodes.
        """
        dataset_dir = os.path.dirname(
            config.DATA_PATH.format(split=config.SPLIT)
        )
        if not cls.check_config_paths_exist(config):
            raise FileNotFoundError(
                f"Could not find dataset file `{dataset_dir}`"
            )
本文含有隐藏内容,请 开通VIP 后查看

网站公告


今日签到

点亮在社区的每一天
去签到