evox.problems.neuroevolution.supervised_learning
¶
模块内容¶
类¶
监督学习问题是使用给定的数据和标准来测试模型参数或一批参数。 |
API¶
- class evox.problems.neuroevolution.supervised_learning.SupervisedLearningProblem(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, n_batch_per_eval: int = 1, pop_size: int | None = None, device: torch.device | None = None, reduction: str = 'mean')[源代码]¶
Bases:
evox.core.Problem
监督学习问题是使用给定的数据和标准来测试模型参数或一批参数。
初始化
初始化
SupervisedLearningProblem
。- 参数:
model -- 需要评估参数的神经网络模型。
data_loader -- 用于评估的数据加载器提供数据集。
criterion -- 用于评估参数性能的损失函数。
n_batch_per_eval -- 每次评估中要计算的批次数。当设置为 -1 时,将遍历整个数据集。默认值为 1。
pop_size -- 种群的大小(参数的批量大小)需要进行评估。默认为 None,表示单次运行模式。
device -- 用于运行计算的设备。默认为当前默认设备。
reduction -- 用于标准的归约方法。'mean' | 'sum'。默认值为“mean”。
- 抛出:
RuntimeError -- 如果数据加载器不包含任何项目。
警告
此问题不支持 HPO 包装器 (
problems.hpo_wrapper.HPOProblemWrapper
),即包含此问题的工作流不能被 vmapped。- _vmap_forward_pass(model_state: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], data: Tuple[torch.Tensor, torch.Tensor])[源代码]¶