import torch
from . import _functional as F
from .optimizer import Optimizer


class Adamax(Optimizer):
    """Implements Adamax algorithm (a variant of Adam based on infinity norm).

    It has been proposed in `Adam: A Method for Stochastic Optimization`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    __ https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Adamax, self).__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_infs = []
            state_steps = []

            beta1, beta2 = group['betas']
            eps = group['eps']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Adamax does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_inf'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_infs.append(state['exp_inf'])

                state['step'] += 1
                state_steps.append(state['step'])

            F.adamax(params_with_grad,
                     grads,
                     exp_avgs,
                     exp_infs,
                     state_steps,
                     eps=eps,
                     beta1=beta1,
                     beta2=beta2,
                     lr=lr,
                     weight_decay=weight_decay)

        return loss
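For reference, the per-parameter arithmetic behind the F.adamax call above follows the Adamax rule from the paper linked in the docstring. The snippet below is a minimal illustrative sketch, not the library's internal implementation: the helper name adamax_update_sketch and its in-place tensor arithmetic are assumptions chosen to mirror the state kept in state['exp_avg'], state['exp_inf'], and state['step'].

import torch


def adamax_update_sketch(param, grad, exp_avg, exp_inf, step,
                         lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                         weight_decay=0):
    # Hypothetical helper illustrating the Adamax update; not torch internals.
    if weight_decay != 0:
        # Fold the L2 penalty into the gradient, matching the weight_decay option.
        grad = grad.add(param, alpha=weight_decay)

    # First moment: exponential moving average of gradients.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

    # Infinity-norm term: exponentially decayed running maximum of |grad|.
    exp_inf.copy_(torch.maximum(exp_inf.mul(beta2), grad.abs().add(eps)))

    # Bias-correct only the first moment; the max-based term needs no correction.
    clr = lr / (1 - beta1 ** step)
    param.addcdiv_(exp_avg, exp_inf, value=-clr)

Because the denominator is a running maximum rather than a root-mean-square, Adamax avoids the second-moment bias correction that Adam requires.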
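Usage is the same as for any other optimizer in torch.optim: construct it over the model's parameters, then alternate backward() and step(). The sketch below uses a placeholder linear model, synthetic data, and arbitrary hyperparameters purely for illustration.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                       # placeholder model
optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3,
                               betas=(0.9, 0.999), weight_decay=1e-4)

inputs = torch.randn(32, 10)                   # synthetic batch
targets = torch.randn(32, 1)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()                           # runs the step() defined above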