Using the CerebrasEstimator¶
The CerebrasEstimator is a central part of your main Python program when running on the CS system. When one of its methods, such as compile or train, is called and the IP address of the CS system is provided with cs_ip, the CerebrasEstimator launches the Cerebras Graph Compiler (CGC). See also The CerebrasEstimator Interface.
In this section, an example run.py template is used to show how the CerebrasEstimator interacts with the key code segments of your Python program.
Note
For a detailed description of the example run.py template, see The run.py Template.
The following is a highly simplified run.py example for neural network training:
 1 # Example run.py script for neural network training
 2 from cerebras.models.common.estimator.tf.cs_estimator import CerebrasEstimator
 3 from cerebras.models.common.estimator.tf.run_config import CSRunConfig
 4 from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver
 5
 6 def model_fn(features, labels, mode, params):
 7
 8     ...
 9
10     return spec
11
12 def input_fn(params):
13
14     ...
15
16     return dataset
17
18 config = CSRunConfig(
19     cs_ip=ip,
20     save_checkpoints_steps=1000,
21     log_step_count_steps=10000,
22 )
23 params = {
24     "batch_size": 32,
25     "lr": 0.1,
26     "use_cbfloat16": True
27 }
28
29 est = CerebrasEstimator(
30     model_fn,
31     config=config,
32     params=params,
33     model_dir='./out',
34     use_cs=True
35 )
36
37 est.train(input_fn, steps=100000)
Calling the CerebrasEstimator¶
In the est=CerebrasEstimator(...) call (line 29), the model_fn argument is a callback function. The CerebrasEstimator does not call model_fn when it receives this argument. Instead, it stores the callback and waits until one of its methods, such as train, is invoked.
Note
The model_fn argument to the CerebrasEstimator interface is passed without the ().
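The deferred-callback behavior described above can be sketched in plain Python. ToyEstimator is a hypothetical stand-in written only for this illustration. It is not the CerebrasEstimator and makes no claim about how the real class is implemented. The sketch only shows that passing a function without the () hands over the function object, which is called later.

```python
# Hypothetical stand-in for illustration only; this is NOT the CerebrasEstimator.
class ToyEstimator:
    def __init__(self, model_fn, params=None):
        # Store the function object itself; nothing is called here.
        self._model_fn = model_fn
        self._params = params or {}

    def train(self, input_fn, steps):
        # Only now are the callbacks invoked, by the estimator and not by user code.
        features, labels = input_fn(self._params)
        return self._model_fn(features, labels, "train", self._params)


calls = []  # records the order in which the callbacks run


def model_fn(features, labels, mode, params):
    calls.append("model_fn")
    return {"mode": mode, "lr": params["lr"]}


def input_fn(params):
    calls.append("input_fn")
    return [1.0, 2.0], [0, 1]  # toy (features, labels)


est = ToyEstimator(model_fn, params={"lr": 0.1})  # model_fn, not model_fn()
calls_after_construction = list(calls)            # still empty at this point
spec = est.train(input_fn, steps=1)               # input_fn, not input_fn()
```

Writing model_fn() instead would run the function immediately and pass its return value, which is not what an estimator-style interface expects.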
Callback input function¶
The est.train(input_fn, steps=100000) call (line 37) is a train method call to the CerebrasEstimator with the input_fn argument as a callback function. The CerebrasEstimator then calls the input_fn with the params argument.
Note
The input_fn argument to the train method is passed without the ().
Both the CerebrasEstimator and TensorFlow Estimator API expect the input function to:
Accept a standard group of input parameters with the argument params, and
Return a tf.data.Dataset that yields tensor pairs in the predefined format: a tensor with features and a tensor with labels.
Any params passed to the CerebrasEstimator are passed on to the input_fn and to the model_fn when the CerebrasEstimator calls these functions.
The input_fn should return a tf.data.Dataset (see Dataset API for documentation).
The input function builds the input pipeline and yields the batched data in the form of (features, labels) pairs, where:
features can be a tensor or a dictionary of tensors, and
labels can be a tensor, a dictionary of tensors, or None.
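The shape of this contract can be sketched without TensorFlow. In the sketch below a plain Python generator stands in for the tf.data.Dataset, and lists stand in for tensors. Both are assumptions made so the sketch runs anywhere. The field names "x" and "y" and the toy data are hypothetical. A real input_fn must return a tf.data.Dataset.

```python
# Plain-Python stand-in: a generator plays the role of tf.data.Dataset here.
def input_fn(params):
    batch_size = params["batch_size"]  # read from the params passed by the estimator
    # Toy data: 6 examples, each with two feature fields and one integer label.
    examples = [({"x": float(i), "y": float(2 * i)}, i % 2) for i in range(6)]

    def batches():
        # Yield only full batches, as (features, labels) pairs.
        for start in range(0, len(examples) - batch_size + 1, batch_size):
            chunk = examples[start:start + batch_size]
            features = {
                "x": [f["x"] for f, _ in chunk],
                "y": [f["y"] for f, _ in chunk],
            }
            labels = [label for _, label in chunk]
            yield features, labels

    return batches()


params = {"batch_size": 2, "lr": 0.1}
first_features, first_labels = next(input_fn(params))
```

Here features is a dictionary of batched values and labels is a single batched value, which mirrors the "dictionary of tensors" and "tensor" cases listed above.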
Example¶
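def input_fn(params):
    ...  # dataset construction elided
    ds = ds.shuffle(buffer_size)                    # randomize example order
    ds = ds.repeat()                                # repeat the data indefinitely
    ds = ds.batch(batch_size, drop_remainder=True)  # keep only full batches
    ds = ds.prefetch(buffer_size)                   # overlap input with compute
    return ds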
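The effect of drop_remainder=True in the batch call can be illustrated with a small plain-Python sketch. The batch helper below is hypothetical and only mimics the grouping behavior on a list. It is not the tf.data implementation. With drop_remainder=True, a final batch that has fewer than batch_size elements is discarded, so every batch has the same size.

```python
def batch(items, batch_size, drop_remainder):
    """Group items into lists of batch_size, mimicking the effect of ds.batch()."""
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    # Discard a short final batch when drop_remainder is requested.
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


items = list(range(10))
kept = batch(items, 4, drop_remainder=False)    # last batch has only 2 items
dropped = batch(items, 4, drop_remainder=True)  # short last batch is removed
```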
Callback model function¶
The model function model_fn is used to generate the graph for your neural network model.
The features and labels, the two values returned from the input_fn, are the handles to the batched data that your model will use. When features and labels are returned from the input_fn, the CerebrasEstimator calls the model_fn, passing them together with the following arguments:
The mode argument, which indicates whether the caller is requesting training.
The params object that was passed in the est=CerebrasEstimator(...) call.
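How a model function can branch on mode is sketched below in plain Python. This is a toy one-weight linear model invented for illustration, not a Cerebras or TensorFlow model. The "w" entry in params and the returned dictionary are hypothetical. Plain strings are used for the modes because tf.estimator.ModeKeys.TRAIN and tf.estimator.ModeKeys.EVAL are the strings "train" and "eval". The function is called directly here only to show the arguments. In a real run the estimator makes this call.

```python
TRAIN, EVAL = "train", "eval"  # same string values as tf.estimator.ModeKeys


def model_fn(features, labels, mode, params):
    w = params["w"]  # hypothetical single model weight
    predictions = [w * x for x in features]
    n = len(labels)
    # Mean squared error is computed in both modes.
    loss = sum((p - y) ** 2 for p, y in zip(predictions, labels)) / n
    result = {"mode": mode, "loss": loss}
    if mode == TRAIN:
        # Only training performs a gradient-descent update of the weight.
        grad = sum(2 * (p - y) * x for p, y, x in zip(predictions, labels, features)) / n
        result["new_w"] = w - params["lr"] * grad
    return result


features, labels = [1.0, 2.0], [2.0, 4.0]
params = {"w": 1.0, "lr": 0.1}
train_out = model_fn(features, labels, TRAIN, params)
eval_out = model_fn(features, labels, EVAL, params)
```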
Important
The functions input_fn and model_fn are called by the CerebrasEstimator, because both are passed to the CerebrasEstimator as callback functions. You should not directly call either of these two functions in your TensorFlow code.
Both the CerebrasEstimator and TensorFlow Estimator API expect the model function to accept a standard group of input parameters and return a standard group of output values.
Currently, the CerebrasEstimator supports usage of the TensorFlow Keras Layers API in the model function. However, the TensorFlow Metrics API is not supported.
Syntax¶
def model_fn(
    features,  # This is batch_features from input_fn
    labels,    # This is batch_labels from input_fn
    mode,      # An instance of tf.estimator.ModeKeys
    params     # Additional configuration
):
Example¶
The following is an example of a model_fn definition.
import tensorflow as tf  # TensorFlow 1.x API

def model_fn(features, labels, mode=tf.estimator.ModeKeys.TRAIN, params=None):
    """ Model definition """
    # build_model is a user-defined function (not shown) that returns the logits
    logits = build_model(features, params)
    learning_rate = tf.constant(params["lr"])
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
        )
        train_op = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate
        ).minimize(loss_op, global_step=tf.train.get_global_step())
        spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)
        return spec
    raise ValueError("Unsupported mode: {}".format(mode))
Setting the runtime configuration¶
Runtime and environment options, which usually cover information that is not captured in the model_fn and input_fn, are set with the CSRunConfig object. These Cerebras-specific options are an extension of the TensorFlow RunConfig.
Important
Make sure to add the following import statement to your Slurm-orchestrated TensorFlow code so that Slurm cluster resolving is done automatically.
from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver
CSRunConfig¶
The Cerebras CSRunConfig class inherits from the standard TensorFlow RunConfig class. You can pass to the CSRunConfig the same parameters as those of the TensorFlow RunConfig, and also pass additional parameters that specify the configuration for a CerebrasEstimator run, including the IP address of the CS system. Such additional parameters include:
cs_ip: IP address of the CS system, provided by Cerebras.
system_name: Name of the CS system.
For the full list of options for the TensorFlow RunConfig, see the TensorFlow RunConfig documentation.
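One common way for a configuration class to accept both its parent's parameters and its own is sketched below in plain Python. Both classes are hypothetical stand-ins written for this illustration. They are not the TensorFlow RunConfig or the Cerebras CSRunConfig, and the sketch does not describe how CSRunConfig is actually implemented. The IP address is a placeholder.

```python
class BaseRunConfig:
    """Hypothetical stand-in for a parent configuration class."""

    def __init__(self, save_checkpoints_steps=None, log_step_count_steps=100):
        self.save_checkpoints_steps = save_checkpoints_steps
        self.log_step_count_steps = log_step_count_steps


class ToyCSRunConfig(BaseRunConfig):
    """Hypothetical stand-in for a subclass that adds system-specific options."""

    def __init__(self, cs_ip=None, system_name=None, **kwargs):
        # Forward every parent-class option unchanged to the parent constructor.
        super().__init__(**kwargs)
        # Keep the additional, system-specific options on the subclass.
        self.cs_ip = cs_ip
        self.system_name = system_name


config = ToyCSRunConfig(
    cs_ip="10.0.0.1",  # placeholder address, not a real CS system
    save_checkpoints_steps=1000,
    log_step_count_steps=10000,
)
```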
Example¶
from cerebras.models.common.estimator.tf.run_config import CSRunConfig
from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver

# ip holds the IP address of the CS system and must be defined beforehand
config = CSRunConfig(
    cs_ip=ip,
    save_checkpoints_steps=1000,
    log_step_count_steps=10000,
    save_summary_steps=1000
)