Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizer use init program #5275

Merged
5 changes: 5 additions & 0 deletions python/paddle/v2/framework/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
__all__ = ['Block', 'Variable', 'Program', 'Operator']


def unique_name(prefix):
uid = core.unique_integer(prefix) # unique during whole process.
return "_".join([prefix, str(uid)])


class Variable(object):
def __init__(self,
block,
Expand Down
23 changes: 13 additions & 10 deletions python/paddle/v2/framework/layer_helper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import copy
import itertools

import paddle.v2.framework.core as core

from paddle.v2.framework.framework import Variable, g_program, \
g_init_program
g_init_program, unique_name, Program
from paddle.v2.framework.initializer import ConstantInitializer, \
UniformInitializer


def unique_name(prefix):
uid = core.unique_integer(prefix) # unique during whole process.
return "_".join([prefix, str(uid)])


class LayerHelper(object):
def __init__(self, layer_type, **kwargs):
self.kwargs = kwargs
Expand Down Expand Up @@ -138,9 +131,19 @@ def create_tmp_variable(self, dtype):
def create_variable(self, *args, **kwargs):
return self.program.current_block().create_var(*args, **kwargs)

def create_global_variable(self, *args, **kwargs):
def create_global_variable(self, persistable=False, *args, **kwargs):
return self.program.global_block().create_var(
*args, persistable=False, **kwargs)
*args, persistable=persistable, **kwargs)

def set_variable_initializer(self, var, initializer):
assert isinstance(var, Variable)
self.init_program.global_block().create_var(
name=var.name,
type=var.type,
dtype=var.data_type,
shape=var.shape,
persistable=True,
initializer=initializer)

def append_bias_op(self, input_var, num_flatten_dims=None):
"""
Expand Down
Loading