cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_train
- cerebras.modelzoo.common.run_cstorch_flow.run_cstorch_train(params, model_fn, input_fn, cs_config, artifact_dir)
Runs the training workflow built using the cstorch API.
- Parameters
params – The params dictionary extracted from the params.yaml file used for the run
model_fn – A callable that takes the params dictionary and returns a torch.nn.Module
input_fn – A callable that takes the params dictionary and returns a torch.utils.data.DataLoader
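The sketch below shows a pair of callables that satisfy the model_fn and input_fn contracts described above, using only plain PyTorch. The params keys ("model", "train_input", "input_dim", "num_classes", "batch_size", "num_samples"), the small MLP, and the random synthetic data are all hypothetical choices for illustration; a real Model Zoo params.yaml defines its own layout, and this is not the Model Zoo's own model or data pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical params layout; real keys come from your params.yaml.
params = {
    "model": {"input_dim": 8, "num_classes": 2},
    "train_input": {"batch_size": 4, "num_samples": 16},
}


def model_fn(params):
    """Hypothetical model_fn: builds a torch.nn.Module from the params dict."""
    cfg = params["model"]
    return torch.nn.Sequential(
        torch.nn.Linear(cfg["input_dim"], 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, cfg["num_classes"]),
    )


def input_fn(params):
    """Hypothetical input_fn: builds a torch.utils.data.DataLoader from the params dict."""
    m, t = params["model"], params["train_input"]
    # Random synthetic features and labels stand in for a real dataset.
    features = torch.randn(t["num_samples"], m["input_dim"])
    labels = torch.randint(0, m["num_classes"], (t["num_samples"],))
    return DataLoader(
        TensorDataset(features, labels),
        batch_size=t["batch_size"],
        shuffle=True,
    )
```

Callables like these would then be passed positionally as run_cstorch_train(params, model_fn, input_fn, cs_config, artifact_dir); constructing cs_config and choosing artifact_dir are outside the scope of this sketch.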