evox.workflows.std_workflow
#
模块内容#
类#
标准工作流。 |
API#
- class evox.workflows.std_workflow.StdWorkflow(opt_direction: str = 'min')[源代码]#
Bases:
evox.core.Workflow
标准工作流。
Usage:
algo = BasicAlgorithm(10) algo.setup(-10 * torch.ones(2), 10 * torch.ones(2)) prob = BasicProblem() class solution_transform(nn.Module): def forward(self, x: torch.Tensor): return x / 5 class fitness_transform(nn.Module): def forward(self, f: torch.Tensor): return -f monitor = EvalMonitor(full_sol_history=True) workflow = StdWorkflow() workflow.setup(algo, prob, solution_transform=solution_transform(), fitness_transform=fitness_transform(), monitor=monitor) monitor = workflow.get_submodule("monitor") workflow.init_step() print(monitor.topk_fitness) workflow.step() print(monitor.topk_fitness) # run rest of the steps ...
初始化
使用静态参数初始化标准工作流。
- 参数:
opt_direction -- 优化方向只能是“min”或“max”。默认为“min”。如果是“max”,适应度将在
fitness_transform
和monitor
之前取反。
- setup(algorithm: evox.core.Algorithm, problem: evox.core.Problem, monitor: evox.core.Monitor | None = None, solution_transform: torch.nn.Module | None = None, fitness_transform: torch.nn.Module | None = None, device: str | torch.device | int | None = None, algorithm_setup_params: Dict[str, Any] | None = None, problem_setup_params: Dict[str, Any] | None = None, monitor_setup_params: Dict[str, Any] | None = None)[源代码]#
将模块设置为子模块初始化。由于所有这些参数都是可变模块,将作为子模块添加,因此它们被放置在这里而不是
__init__
中,因此setup
必须在__init__
之后调用。- 参数:
algorithm -- 在工作流中要使用的算法。
problem -- 在工作流中要使用的问题。
monitors -- 在工作流中使用的监视器。默认为 None。注意:通常情况下,监视器只能在使用 JIT script 模式时使用。
solution_transform -- 解决方案转换函数。必须是兼容JIT的模块/函数,用于JIT跟踪模式,或者是JIT脚本模式(默认模式)的普通模块。默认为None。
fitness_transforms -- 适应度转换函数。必须是与JIT兼容的模块/函数,用于JIT追踪模式,或者是用于JIT脚本模式(默认模式)的普通模块。默认值为None。
device -- 工作流的设备。默认为 None。
algorithm_setup_params -- 要传递给
algorithm.setup(**kwargs)
的参数。如果未提供,则不会调用 `algorithm.setup()problem_setup_params -- 要传递给
problem.setup(**kwargs)
的参数。如果未提供,则不会调用 `problem.setup()monitor_setup_params -- 要传递给
monitor.setup(**kwargs)
的参数。如果未提供,则不会调用 `monitor.setup()
注意
算法、问题和监视器将被原地转换到目标设备。
- __getattribute__(name: str)#