cd.callbacks
- class StepDropout(step_size, base_drop_rate, gamma=0.0, update_interval='epoch', log=True, log_name='drop_rate', ascending=False, **kwargs)
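Step Dropout.
A simple dropout scheduler that adjusts the drop rate in discrete steps: every step_size epochs or training steps (see update_interval), the drop rate decays by the factor gamma.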
Examples
>>> from pytorch_lightning import Trainer
>>> # Early Dropout (drop rate from .1 to 0 after 50 epochs)
>>> trainer = Trainer(callbacks=[StepDropout(50, base_drop_rate=.1, gamma=0.)])
>>> # Late Dropout (drop rate from 0 to .1 after 50 epochs)
>>> trainer = Trainer(callbacks=[StepDropout(50, base_drop_rate=.1, gamma=0., ascending=True)])
- Parameters:
step_size – Period of drop rate decay, counted in units of update_interval.
base_drop_rate – Base drop rate.
gamma – Multiplicative factor of drop rate decay. Default: 0., which replicates “Early Dropout”.
update_interval – One of ('step', 'epoch'): whether the drop rate is updated per training step or per epoch.
log – Whether to log drop rates using module.log(log_name, drop_rate).
log_name – Name for logging.
logger – If True, logs to the logger.
ascending – If True, the drop rate schedule runs from right to left, i.e. it starts at 0 and ascends towards base_drop_rate. Using ascending=True, gamma=0. replicates “Late Dropout”.
**kwargs – Keyword arguments for module.log.
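The schedule implied by step_size, gamma and ascending can be sketched as below. This is a minimal sketch, not the library's get_rate: the descending branch assumes a StepLR-style rule, base * gamma ** (step // step_size), and the ascending branch assumes that curve is mirrored as base - base * gamma ** (step // step_size) so it starts at 0 and rises towards base. The function name step_drop_rate is hypothetical.

```python
def step_drop_rate(base, gamma, step, step_size, ascending=False):
    """Hypothetical step-decay drop-rate schedule (assumed, not the library code).

    Descending (default): base * gamma ** (step // step_size).
    Ascending (assumed):  base - base * gamma ** (step // step_size),
    i.e. it starts at 0 and rises towards base.
    """
    decayed = base * gamma ** (step // step_size)  # note: 0.0 ** 0 == 1.0
    return base - decayed if ascending else decayed


epochs = (0, 49, 50, 99)
# Early Dropout: 0.1 for epochs 0-49, then 0.
early = [step_drop_rate(0.1, 0.0, e, 50) for e in epochs]
# Late Dropout: 0 for epochs 0-49, then 0.1.
late = [step_drop_rate(0.1, 0.0, e, 50, ascending=True) for e in epochs]
print(early)  # [0.1, 0.1, 0.0, 0.0]
print(late)   # [0.0, 0.0, 0.1, 0.1]
```

With gamma=0. the schedule has only two levels, which is why gamma=0. is the natural choice for reproducing Early and Late Dropout; a gamma between 0 and 1 would instead shrink the rate geometrically every step_size updates.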
- static get_rate(base, gamma, step, step_size, ascending)
- on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) → None
- on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
- update_drop_rate(pl_module: LightningModule, drop_rate: float)
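How the two hooks might hand off to update_drop_rate can be sketched without Lightning. This is a hypothetical illustration, not the callback's source: it assumes the 'epoch' schedule reads trainer.current_epoch, the 'step' schedule reads trainer.global_step, and that every submodule with an attribute p (as torch.nn.Dropout has) is a dropout layer. StepDropoutSketch and its schedule argument are invented names; SimpleNamespace objects stand in for a real Trainer and LightningModule.

```python
from types import SimpleNamespace


class StepDropoutSketch:
    """Hypothetical, Lightning-free sketch of the hook dispatch (not the library code)."""

    def __init__(self, schedule, update_interval='epoch', log=True,
                 log_name='drop_rate', **kwargs):
        if update_interval not in ('step', 'epoch'):
            raise ValueError("update_interval must be 'step' or 'epoch'")
        self.schedule = schedule  # hypothetical: maps a counter to a drop rate
        self.update_interval = update_interval
        self.log = log
        self.log_name = log_name
        self.log_kwargs = kwargs  # forwarded to pl_module.log

    def on_train_epoch_start(self, trainer, pl_module):
        # Epoch-wise schedule, assumed to be driven by trainer.current_epoch.
        if self.update_interval == 'epoch':
            self.update_drop_rate(pl_module, self.schedule(trainer.current_epoch))

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Step-wise schedule, assumed to be driven by trainer.global_step.
        if self.update_interval == 'step':
            self.update_drop_rate(pl_module, self.schedule(trainer.global_step))

    def update_drop_rate(self, pl_module, drop_rate):
        # Assumption: any submodule with an attribute `p` is treated as a
        # dropout layer and receives the new rate.
        for module in pl_module.modules():
            if hasattr(module, 'p'):
                module.p = drop_rate
        if self.log:
            pl_module.log(self.log_name, drop_rate, **self.log_kwargs)


# Usage with stand-in objects instead of a real Trainer / LightningModule.
logged = []
dropout = SimpleNamespace(p=0.5)
model = SimpleNamespace(modules=lambda: [dropout],
                        log=lambda name, value, **kw: logged.append((name, value)))
cb = StepDropoutSketch(schedule=lambda s: 0.1 if s < 50 else 0.0)
cb.on_train_epoch_start(SimpleNamespace(current_epoch=0), model)
print(dropout.p)  # 0.1
cb.on_train_epoch_start(SimpleNamespace(current_epoch=50), model)
print(dropout.p)  # 0.0
print(logged)     # [('drop_rate', 0.1), ('drop_rate', 0.0)]
```

Keeping both hooks but gating each on update_interval means only one counter ever drives the schedule, so switching between per-epoch and per-step updates needs no other change.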