Skip to content

Commit

Permalink
Merge pull request mmistakes#77 from pesser/scheduling_variables
Browse files Browse the repository at this point in the history
Scheduling variables
  • Loading branch information
pesser authored Jun 24, 2019
2 parents e0a1e35 + 9ad803c commit d7e1e63
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions edflow/tf_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
import numpy as np


def make_linear_var(
Expand Down Expand Up @@ -50,3 +51,36 @@ def make_exponential_var(step, start, end, start_value, end_value, decay):
endstep = (np.log(end_value) - np.log(start_value)) / np.log(decay)
stepper = make_linear_var(step, start, end, startstep, endstep)
return tf.math.pow(decay, stepper) * start_value


def make_var(step, var_type, options):
"""
# usage within trainer
grad_weight = make_var(step=self.global_step,
var_type=self.config["grad_weight"]["var_type"],
options=self.config["grad_weight"]["options"])
# within yaml file
grad_weight:
var_type: linear
options:
start: 50000
end: 60000
start_value: 0.0
end_value: 1.0
clip_min: 1.0e-6
clip_max: 1.0
Parameters
----------
step
var_type
options
Returns
-------
"""
switch = {"linear": make_linear_var, "exponential": make_exponential_var}
return switch[var_type](step=step, **options)

0 comments on commit d7e1e63

Please sign in to comment.