How to Update Trainable Variables on My Custom Optimizer Using Tensorflow 2

I'm a newbie studying Convolution Neural Network nowadays.

So, I have been implementing AlexNet referring a paper titled "ImageNet Classification with Deep Convolution Neural Networks", which uses Tensorflow 2.3 in anaconda environment. However, I got frustrated while I'm implementing a custom optimizer.

My problem : I have to modify the optimizer according to AlexNet paper. I can't find out the reference How to update the variable that uses TensorflowV2, I had googling though. There is only "tf.assign()" uses, It is not supported in Tensorflow V2, but I'm also afraid of the compatibility between V1 and V2 if I'm going to use this function.

I just know that I have to customize "_resource_apply_dense" function for adapting my update rule. Then, I loaded some hyperparameter in there. but I don't know how to update the hyperparameter.

(tf.Variable() can use as a python variable, so I suppose that it is the same as tf.Variable()....?)

Thank you all readers in advanced ^_^

Here's the code.

- update rule in AlexNet
 v_(i+1)= momentum * v_i- 0.0005 * learning_rate * w_i - learning_rate * gradient
 w_(i+1) = w_i + v_(i+1)

# where
# w_i = weight
# v_i = velocity
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export
import tensorflow as tf

class AlexSGD(optimizer_v2.OptimizerV2):

    _HAS_AGGREGATE_GRAD = True

    def __init__(self,
                learning_rate=0.01,
                weight_decay=0.0005,
                momentum=0.9,
                name="AlexSGD",
                **kwargs):
        super(AlexSGD, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        
        self._is_first = True
        self._set_hyper("vx", 0)
        self._set_hyper("pg", 0)
        self._set_hyper("pv", 0)
        self._weight_decay = False
        if isinstance(weight_decay, ops.Tensor) or callable(weight_decay) or 
            weight_decay > 0:
        self._weight_decay = True
        if isinstance(weight_decay, (int, float)) and (weight_decay < 0 or 
            weight_decay > 1):
        raise ValueError("`weight_decay` must be between [0, 1].")
        self._set_hyper("weight_decay", weight_decay)

        self._momentum = False
        if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
        self._momentum = True
        if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
        raise ValueError("`momentum` must be between [0, 1].")
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        if self._momentum:
        for var in var_list:
            self.add_slot(var, "momentum")
        if self._weight_decay:
        for var in var_list:
            self.add_slot(var, "weight_decay")
        for var in var_list:
        self.add_slot(var, 'pv') # previous variable i.e. weight or bias    
        for var in var_list:
        self.add_slot(var, 'pg') # previous gradient
        for var in var_list:
        self.add_slot(var, 'vx') # update velocity

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AlexSGD, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
            self._get_hyper("momentum", var_dtype))
        apply_state[(var_device, var_dtype)]["weight_decay"] = array_ops.identity(
            self._get_hyper("weight_decay", var_dtype))
        apply_state[(var_device, var_dtype)]["vx"] = array_ops.identity(
            self._get_hyper("vx", var_dtype))
        apply_state[(var_device, var_dtype)]["pv"] = array_ops.identity(
            self._get_hyper("pv", var_dtype))
        apply_state[(var_device, var_dtype)]["pg"] = array_ops.identity(
            self._get_hyper("pg", var_dtype))

    # main function
    @tf.function
    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        momentum_var = self.get_slot(var, "momentum")
        weight_decay_var = self.get_slot(var, "weight_decay")
        vx_var = self.get_slot(var, "vx")
        pv_var = self.get_slot(var, "pv")
        pg_var = self.get_slot(var, "pg")
        lr_t = self._decayed_lr(var_dtype)

        # update rule in AlexNet
        # v_(i+1) = momentum * v_i - 0.0005 * lr * w_i - lr * grad
        # w_(i+1) = w_i + v_(i+1)
        # where
        # w_i = var
        # vx, v_i = velocity (Feel like I need to set this variable as a slot) 
        # lr = learning_rate
        # grad = gradient
        
        # I'm confused why pv, pg variables are declared... 
        # does it replace by var & grad ?  (pv, pg refer from blog)
        # pv = previous var
        # pg = previous gradient

        if self._is_first:
        self._is_first = False
        vx_var = grad
        new_var = var + vx_var
        else:
        vx_var = momentum_var * vx_var - weight_decay_var*lr_t*pv_var- 
        lr_t*pg_var
        new_var = var + vx_var

        print("grad:",grad)
        print("var:",var)
        print("vx_var:",vx_var)
        print("new_var:",new_var)
        print("pv_var:",pv_var)
        print("pg_var:",pg_var)
        
        # TODO: I got stuck how update the variables because tf.assign() function 
        #       is deprecated in Tensorflow V2 
        pg_var = grad
        pv_var = var

        if var == new_var:
        var = new_var
        
        # TODO: In order to update variables, I can't find the equivalent 
        #        "tf.assign" method in TF V2
        # pg_var.assign(grad)
        # vx_var.assign(vx_var)
        # var.assign(new_var)
        
        
        
        """
        # TODO: I referred the below code from Tensorflow official document, and I 
        #       realized the training_ops module is in c++ library, So I thought I 
        #       can't modify it ( Cuz I need to modify an update function of 
        #       velocity 

        # return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)


        # if self._momentum :
        #  momentum_var = self.get_slot(var, "momentum")
        #   return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)
        # else:
        #   return training_ops.resource_apply_gradient_descent(
        #       var.handle, coefficients["lr_t"], grad, 
                use_locking=self._use_locking)
        """


    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError

    def get_config(self):
        config = super(AlexSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

Reference: https://www.kdnuggets.com/2018/01/custom-optimizer-tensorflow.html

This is another Reference, it also uses tf.assign()

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
import tensorflow as tf

class AlexOptimizer(optimizer.Optimizer):
    def __init__(self, learning_rate="learning_rate",alpha="alpha",beta="beta", #weight_decay="weight_decay", use_locking=False, name="AlexOptimizer"):
        super(AlexOptimizer, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._wd = weight_decay
        self._alpha = alpha
        self._beta = beta
        # Tensor versions of the constructor arguments, created in _prepare().
        self._lr_t = None
        self._wd_t = None
        self._alpha_t = None
        self._beta_t = None

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._wd_t = ops.convert_to_tensor(self._wd, name="weight_decay")
        self._alpha_t = ops.convert_to_tensor(self._beta, name="alpha_t")
        self._beta_t = ops.convert_to_tensor(self._beta, name="beta_t")

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)

    def _apply_dense(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        wd_t = math_ops.cast(self._wd_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
        beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

        eps = 1e-7 #cap for moving average
        m = self.get_slot(var, "m")
        m_t = m.assign(tf.maximum(beta_t * m + eps, tf.abs(grad)))

        var_update = state_ops.assign_sub(var, lr_t*grad*tf.exp( tf.log(alpha_t)*tf.sign(grad)*tf.sign(m_t))) 
        # Update 'ref' by subtracting value
        # Create an op that groups multiple operations.
        # When this op finishes, all ops in input have finished
        return control_flow_ops.group(*[var_update, m_t])
    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradient updates are not supported.")

I would like to modify the code using OptimizerV2, What should I update the variables?

(p.s.) is it right the "@tf.function" usage above "def _resource_apply_dense()"?

on the other hand, my model is continuously shuffling on and on during training ㅠ_ㅠ (this code in a procedure of dataset preprocessing (tf.data.datsets.shuffle()) even it doesn't exist in while loop.......... (sorry for not posting this code... so never mind....)