common.pytorch package#
Subpackages#
- common.pytorch.layers package
- Submodules
- common.pytorch.layers.AlibiPositionEmbeddingLayer module
- common.pytorch.layers.AttentionHelper module
- common.pytorch.layers.AttentionLayer module
- common.pytorch.layers.BCELoss module
- common.pytorch.layers.BCEWithLogitsLoss module
- common.pytorch.layers.BiaslessLayerNorm module
- common.pytorch.layers.CTCLoss module
- common.pytorch.layers.CosineEmbeddingLoss module
- common.pytorch.layers.CrossEntropyLoss module
- common.pytorch.layers.EmbeddingLayer module
- common.pytorch.layers.FeedForwardNetwork module
- common.pytorch.layers.GPTJDecoderLayer module
- common.pytorch.layers.GaussianNLLLoss module
- common.pytorch.layers.HingeEmbeddingLoss module
- common.pytorch.layers.HuberLoss module
- common.pytorch.layers.KLDivLoss module
- common.pytorch.layers.L1Loss module
- common.pytorch.layers.MSELoss module
- common.pytorch.layers.MarginRankingLoss module
- common.pytorch.layers.MultiLabelSoftMarginLoss module
- common.pytorch.layers.MultiMarginLoss module
- common.pytorch.layers.NLLLoss module
- common.pytorch.layers.PoissonNLLLoss module
- common.pytorch.layers.RMSNorm module
- common.pytorch.layers.RelativePositionEmbeddingLayer module
- common.pytorch.layers.SmoothL1Loss module
- common.pytorch.layers.Transformer module
- common.pytorch.layers.TransformerDecoder module
- common.pytorch.layers.TransformerDecoderLayer module
- common.pytorch.layers.TransformerEncoder module
- common.pytorch.layers.TransformerEncoderLayer module
- common.pytorch.layers.TripletMarginLoss module
- common.pytorch.layers.TripletMarginWithDistanceLoss module
- common.pytorch.layers.utils module
- Module contents
- common.pytorch.metrics package
- Submodules
- common.pytorch.metrics.accuracy module
- common.pytorch.metrics.auc module
- common.pytorch.metrics.cb_metric module
- common.pytorch.metrics.dice_coefficient module
- common.pytorch.metrics.fbeta_score module
- common.pytorch.metrics.mean_iou module
- common.pytorch.metrics.mean_per_class_accuracy module
- common.pytorch.metrics.metric_utils module
- common.pytorch.metrics.perplexity module
- common.pytorch.metrics.precision_at_k module
- common.pytorch.metrics.recall_at_k module
- common.pytorch.metrics.rouge_score module
- Module contents
- common.pytorch.model_utils package
- Subpackages
- common.pytorch.model_utils.checkpoint_converters package
- Submodules
- common.pytorch.model_utils.checkpoint_converters.base_converter module
- common.pytorch.model_utils.checkpoint_converters.bert module
- common.pytorch.model_utils.checkpoint_converters.bert_finetune module
- common.pytorch.model_utils.checkpoint_converters.bloom_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.gpt2_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.gpt_neox_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.gptj_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.llama module
- common.pytorch.model_utils.checkpoint_converters.opt_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.salesforce_codegen_hf_cs module
- common.pytorch.model_utils.checkpoint_converters.t5 module
- Module contents
- Submodules
- common.pytorch.model_utils.BertPretrainModelLoss module
- common.pytorch.model_utils.GPTLMHeadModelLoss module
- common.pytorch.model_utils.RotaryPositionEmbeddingHelper module
- common.pytorch.model_utils.T5ForConditionalGenerationLoss module
- common.pytorch.model_utils.activations module
- common.pytorch.model_utils.convert_checkpoint module
- common.pytorch.model_utils.create_initializer module
- common.pytorch.model_utils.weight_initializers module
- Module contents
- common.pytorch.optim package
- Submodules
- common.pytorch.optim.ASGD module
- common.pytorch.optim.Adadelta module
- common.pytorch.optim.Adafactor module
- common.pytorch.optim.Adagrad module
- common.pytorch.optim.AdamBase module
- common.pytorch.optim.Adamax module
- common.pytorch.optim.CSOptimizer module
- common.pytorch.optim.Lamb module
- common.pytorch.optim.Lion module
- common.pytorch.optim.NAdam module
- common.pytorch.optim.RAdam module
- common.pytorch.optim.RMSprop module
- common.pytorch.optim.Rprop module
- common.pytorch.optim.SGD module
- common.pytorch.optim.lr_scheduler module
- common.pytorch.optim.utils module
- Module contents
- common.pytorch.sparsity package
- common.pytorch.summaries package
Submodules#
common.pytorch.PyTorchBaseModel module#
Abstract base class for PyTorch models.
- class common.pytorch.PyTorchBaseModel.Final[source]#
Bases:
type
Placeholder class for deprecation warning
- class common.pytorch.PyTorchBaseModel.PyTorchBaseModel[source]#
Bases:
object
Base Model Definition for Cerebras runners
- __init__(params: dict, model_fn: Union[Callable[[dict], torch.nn.Module], torch.nn.Module], device: Optional[torch.device] = None)[source]#
- property duplicate_params_map#
Returns a map of parameter names which hold the same tensors. Both keys and values are names as they appear in the state_dict.
- property supported_cs_modes#
Returns a list of modes that are supported for CS runs.
By default, we support train and eval; however, this property is designed to be overridden on a model-by-model basis.
- property supported_modes#
Supported modes conditional on hardware backend
- property supported_non_cs_modes#
Returns a list of modes that are supported for non-CS (CPU/GPU) runs.
By default, we support train, eval, and train_and_eval; however, this property is designed to be overridden on a model-by-model basis.
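As an illustration, here is a minimal, hypothetical sketch of subclassing PyTorchBaseModel and overriding one of these properties; the import path and the params/model_fn usage are assumptions based only on the signatures shown above.
import torch
from common.pytorch.PyTorchBaseModel import PyTorchBaseModel


class TinyClassifier(PyTorchBaseModel):
    def __init__(self, params: dict, device=None):
        # model_fn may be a callable that takes params and returns a torch.nn.Module
        super().__init__(
            params=params,
            model_fn=lambda p: torch.nn.Linear(16, 2),
            device=device,
        )

    @property
    def supported_cs_modes(self):
        # Restrict CS runs to training only for this hypothetical model.
        return ["train"]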
common.pytorch.dump_context module#
Provides DumpContext, a debug utility for dumping activations and gradients on a CPU/GPU run, and setting up debug names for dumped WSE activations to be automatically correlated.
- class common.pytorch.dump_context.DumpContext[source]#
Bases:
object
A debug utility context manager. When provided with a torch.nn.Module, the resulting context manager can be entered to enable dumping of all module forward and backward outputs to an .npz file, for comparing numerics between implementations.
Sets up global module hooks to either dump intermediate activations on CPU/GPU or name the traced tensors for correlating with debug dumps on CS2.
The recursive name of the torch.nn.Module is memoized, and the output of FWD and BWD of each module is saved as keys in a .npz file.
- Parameters
outdir – Where to output dumps_{i}.npz
model – root module to name its children
buffer_steps – If given, flush to a new .npz file after this many steps
- __init__(outdir: str, model: torch.nn.Module, buffer_steps: Optional[int] = None)[source]#
Sets up global module hooks to either dump intermediate activations on CPU/GPU or name the traced tensors for correlating with debug dumps on CS2.
The recursive name of the torch.nn.Module is memoized, and the output of FWD and BWD of each module is saved as keys in a .npz file.
- Parameters
outdir – Where to output dumps_{i}.npz
model – root module to name its children
buffer_steps – If given, flush to a new .npz file after this many steps
- disable_collection()[source]#
Uninstall the hooks installed during enable_collection, disabling further dump collection.
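A rough usage sketch of DumpContext follows; the import path and dump file naming come from the description above, while the surrounding model and data are placeholders.
import torch
from common.pytorch.dump_context import DumpContext

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
dump_ctx = DumpContext(outdir="debug_dumps", model=model, buffer_steps=10)

x = torch.randn(4, 8)
with dump_ctx:  # hooks are active only while the context is entered
    out = model(x)
    out.sum().backward()
# Per the docstring, forward/backward outputs land in debug_dumps/dumps_{i}.npz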
common.pytorch.gradient_clipper module#
common.pytorch.half_dtype module#
Module which provides utilities for selecting half dtype between float16 and bfloat16
common.pytorch.input_utils module#
- common.pytorch.input_utils.bucketed_batch(data_iterator, batch_size, buckets=None, element_length_fn=None, collate_fn=None, drop_last=False, seed=None)[source]#
Batch the data from an iterator such that samples of similar length end up in the same batch. If buckets is not supplied, then this just batches the dataset normally.
- Parameters
data_iterator – An iterater that yields data one sample at a time.
batch_size (int) – The number of samples in a batch.
buckets (list) – A list of bucket boundaries. If set to None, then no bucketing will happen, and data will be batched normally. If set to a list, then data will be grouped into len(buckets) + 1 buckets. A sample s will go into bucket i if buckets[i-1] <= element_length_fn(s) < buckets[i] where 0 and inf are the implied lowest and highest boundaries respectively. buckets must be sorted and all elements must be non-zero.
element_length_fn (callable) – A function that takes a single sample and returns an int representing the length of that sample.
collate_fn (callable) – The function to use to collate samples into a batch. Defaults to PyTorch’s default collate function.
drop_last (bool) – Whether or not to drop incomplete batches at the end of the dataset. If using bucketing, buckets that are not completely full will also be dropped, even if combined there are more than batch_size samples remaining spread across multiple buckets.
seed (int) – If using drop_last = False, we don't want to feed out leftover samples with order correlated to their lengths. The solution is to shuffle the leftover samples before batching and yielding them. This seed gives the option to make this shuffle deterministic. It is only used when buckets is not None and drop_last = False.
- Yields
Batches of samples of type returned by collate_fn, or batches of PyTorch tensors if using the default collate function.
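Below is a hedged usage sketch of bucketed_batch that keeps batches as plain Python lists via a pass-through collate_fn; the import path is an assumption.
from common.pytorch.input_utils import bucketed_batch

samples = [list(range(n)) for n in (3, 12, 5, 30, 7, 25)]
batches = bucketed_batch(
    data_iterator=iter(samples),
    batch_size=2,
    buckets=[10, 20],                # three buckets: [0, 10), [10, 20), [20, inf)
    element_length_fn=len,           # bucket samples by their length
    collate_fn=lambda batch: batch,  # keep each batch as a list of samples
    drop_last=False,
    seed=0,
)
for batch in batches:
    print([len(s) for s in batch])   # lengths within a batch share a bucket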
- common.pytorch.input_utils.get_streaming_batch_size(effective_batch_size: int) int [source]#
Returns the streaming batch size of the current task.
In a Wafer-Scale Cluster setup with more than one CS-X node, the batch size specified by the user and used at compile time is the effective batch size at which gradient updates are made. However, each worker node streams a local batch of data to a given CS-X node to enable data-parallel training.
This helper method returns the local batch size that the current task should use given the effective batch size.
- Parameters
effective_batch_size – The effective batch size of the model.
- Returns
The local batch size to be streamed by this task. If queried on the user node (used when compiling the model), this returns the original effective batch size as passed in the argument.
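For example (a sketch; the import path is an assumption):
from common.pytorch.input_utils import get_streaming_batch_size

effective_batch_size = 256
local_batch_size = get_streaming_batch_size(effective_batch_size)
# With e.g. 4 CS-X nodes, each worker would stream 256 // 4 = 64 samples per step;
# on the user node this simply returns 256, as described above.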
common.pytorch.loss_utils module#
Module which provides utilities for aggregating and saving losses
- class common.pytorch.loss_utils.LossSaver[source]#
Bases:
object
Helper class for storing losses during training/eval.
Constructs a LossSaver instance.
- Parameters
writer – Tensorboard summary writer for writing losses.
- __init__(writer: Optional[torch.utils.tensorboard.SummaryWriter] = None)[source]#
Constructs a LossSaver instance.
- Parameters
writer – Tensorboard summary writer for writing losses.
- accumulate(loss: torch.Tensor)[source]#
Accumulates loss values. This method will reduce losses across workers and update a total_loss
- Parameters
loss – The loss tensor whose value will be stored.
- add(loss: torch.Tensor, step: int, epoch: Optional[int] = None)[source]#
Store loss value. This method will reduce losses across workers.
- Parameters
loss – The loss tensor whose value will be stored.
step – Global step at which loss was computed.
epoch – The current epoch.
- property average_loss: float#
Return the average of the accumulated losses
- static mean_reduce(vals: list)[source]#
Apply mean reduction over values.
- Parameters
vals – List of values to apply mean reduction over.
- Returns
The mean reduction of values.
- property total_loss: float#
Return the total accumulated loss
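A hypothetical training-loop sketch using only the LossSaver methods listed above (the import path is an assumption):
import torch
from torch.utils.tensorboard import SummaryWriter
from common.pytorch.loss_utils import LossSaver

writer = SummaryWriter(log_dir="runs/example")
loss_saver = LossSaver(writer=writer)

for step in range(10):
    loss = torch.tensor(1.0 / (step + 1))
    loss_saver.accumulate(loss)               # running total, reduced across workers
    loss_saver.add(loss, step=step, epoch=0)  # record this step's loss for TensorBoard

print(loss_saver.total_loss, loss_saver.average_loss)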
common.pytorch.perf_utils module#
- class common.pytorch.perf_utils.PerfData[source]#
Bases:
object
Data structure for holding performance data.
- Parameters
total_samples – Total number of samples processed.
total_time – Total time spent processing those samples.
samples_per_sec – Throughput of processing those samples.
compile_time – Time spent compiling the model.
programming_time – Time spent programming the fabric.
est_samples_per_sec – Estimated throughput based on compile and programming times.
- __init__(total_samples: int = 0, total_time: float = 0.0, samples_per_sec: float = 0.0, compile_time: float = 0.0, programming_time: float = 0.0, est_samples_per_sec: float = 0.0) None #
- compile_time: float = 0.0#
- est_samples_per_sec: float = 0.0#
- merge(other: common.pytorch.perf_utils.PerfData)[source]#
Merge another PerfData instance into self.
- Parameters
other – The other PerfData instance to merge.
- programming_time: float = 0.0#
- samples_per_sec: float = 0.0#
- total_samples: int = 0#
- total_time: float = 0.0#
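A small sketch of aggregating two runs with merge; the field values come from the constructor signature above, and how derived fields such as samples_per_sec are recombined is up to the implementation.
from common.pytorch.perf_utils import PerfData

run_a = PerfData(total_samples=1000, total_time=10.0, samples_per_sec=100.0)
run_b = PerfData(total_samples=500, total_time=5.0, samples_per_sec=100.0)

run_a.merge(run_b)  # fold run_b's stats into run_a
print(run_a.total_samples, run_a.total_time)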
common.pytorch.pytorch_base_cs_runner module#
Module containing the Base PyTorch CS Runner
common.pytorch.pytorch_base_runner module#
Module containing the Base PyTorch Runner
- class common.pytorch.pytorch_base_runner.PyTorchBaseRunner[source]#
Bases:
object
The base class for running PyTorch models on any device.
Construct a PyTorchRunner instance.
- Parameters
model – The PyTorch model to run.
params – A dict of params that specify the behavior of the model.
- __init__(model: modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel, params: dict)[source]#
Construct a PyTorchRunner instance.
- Parameters
model – The PyTorch model to run.
params – A dict of params that specify the behavior of the model.
- backward(loss)[source]#
Runs the backward pass.
Override this method to provide any additional functionality around the backward call.
- static create(model_fn: Callable[[dict, Optional[torch.device]], modelzoo.common.pytorch.PyTorchBaseModel.PyTorchBaseModel], params: dict) common.pytorch.pytorch_base_runner.PyTorchBaseRunner [source]#
Creates and returns an instance of PyTorchBaseRunner that has been configured based on the hardware specified by the provided params dictionary
- Parameters
model_fn – A callable that takes in a ‘params’ argument and optionally a torch.device which it uses to configure and return a PyTorchBaseModel
params – A dictionary containing all the parameters required to initialize and configure both the model and the runner
- eval_epoch(dataloader, epoch: Optional[int] = None)#
Runs an epoch of evaluation
- Parameters
dataloader – The dataloader to iterate through
epoch – The current epoch
- eval_forward(data)[source]#
Runs the eval forward pass.
Override this method to provide any additional functionality around the eval forward pass call.
- evaluate(eval_dataloader: torch.utils.data.DataLoader)[source]#
Evaluate the model with data generated by the given dataloader.
- Parameters
dataloader – A data loader for generating data to feed to the model.
- is_master_ordinal()[source]#
Checks if distributed mode is enabled and, if so, whether this is the main process. Most reading and writing should happen only on the main process.
- on_checkpoint_saved(checkpoint_path: str, step: int)[source]#
Function to execute after a checkpoint is saved.
- on_eval_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)[source]#
Actions to perform after the eval batch iteration is complete
- on_train_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)[source]#
Actions to perform after the train batch iteration is complete
- train(train_dataloader: torch.utils.data.DataLoader)[source]#
Train the model with data generated by the given dataloader.
- Parameters
dataloader – A data loader for generating data to feed to the model.
- train_and_eval(train_dataloader: torch.utils.data.DataLoader, eval_dataloader: torch.utils.data.DataLoader)[source]#
Train and evaluate the model with data generated by dataloaders.
In each epoch, this method trains the model first, then runs evaluation every epoch.
- Parameters
train_dataloader – A data loader for generating training data to feed to the model.
eval_dataloader – A data loader for generating evaluation data to feed to the model.
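A hedged sketch of this runner flow follows; the params structure is illustrative only, and TinyClassifier refers to the hypothetical PyTorchBaseModel subclass sketched earlier on this page.
import torch
from common.pytorch.pytorch_base_runner import PyTorchBaseRunner

def model_fn(params, device=None):
    return TinyClassifier(params, device=device)  # any PyTorchBaseModel subclass

params = {"runconfig": {"mode": "train"}}  # placeholder; real params come from params.yaml
dataset = torch.utils.data.TensorDataset(torch.randn(32, 16), torch.randint(0, 2, (32,)))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

runner = PyTorchBaseRunner.create(model_fn, params)  # selects a runner for the configured hardware
runner.train(train_loader)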
common.pytorch.pytorch_cs_appliance module#
Contains the CS Appliance mode runner
- class common.pytorch.pytorch_cs_appliance.PyTorchCSAppliance[source]#
Bases:
modelzoo.common.pytorch.pytorch_base_cs_runner.PyTorchBaseCSRunner
Class for compiling PyTorch models for Cerebras hardware.
common.pytorch.pytorch_dist_runner module#
- class common.pytorch.pytorch_dist_runner.PyTorchDistRunner[source]#
Bases:
modelzoo.common.pytorch.pytorch_runner.PyTorchRunner
Class for running PyTorch models on multiple GPUs.
- evaluate(eval_data_fn)[source]#
Evaluate the model with data generated by the given dataloader.
- Parameters
dataloader – A data loader for generating data to feed to the model.
- is_master_ordinal()[source]#
Checks if distributed mode is enabled and, if so, whether this is the main process. Most reading and writing should happen only on the main process.
- on_eval_batch_end(loss, epoch: Optional[int] = None, step: Optional[int] = None)[source]#
Actions to perform after the eval batch iteration is complete
common.pytorch.pytorch_runner module#
common.pytorch.run_cstorch_flow module#
Generic run scripts built using the cstorch API
- common.pytorch.run_cstorch_flow.compute_grad_norm(model)[source]#
Compute the model-wise and per-layer norms of the gradients
- common.pytorch.run_cstorch_flow.compute_params_norm(model)[source]#
Compute the model-wise norm of the parameters
- common.pytorch.run_cstorch_flow.get_latest_checkpoint(model_dir)[source]#
Get the path to the checkpoint with the highest global step
- common.pytorch.run_cstorch_flow.optimizer_step_with_summaries(loss: torch.Tensor, optimizer: cstorch.optim.Optimizer, grad_scaler: cstorch.amp.GradScaler, max_gradient_norm: float = None, max_gradient_value: float = None, log_summaries: bool = False, model: torch.nn.Module = None)[source]#
Customized equivalent to cstorch.amp.optimizer_step additionally featuring grad norm summaries
- common.pytorch.run_cstorch_flow.run_cstorch_eval(params, model_fn, input_fn, cs_config)[source]#
Runs the evaluation workflow built using the cstorch API
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
input_fn – A callable that takes in the params dictionary and returns a torch.utils.data.DataLoader
- common.pytorch.run_cstorch_flow.run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn)[source]#
Set up the cstorch run and call the appropriate helper based on the mode
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
train_data_fn – A callable that takes in the param dictionary and returns a torch.utils.data.DataLoader
eval_data_fn – A callable that takes in the param dictionary and returns a torch.utils.data.DataLoader
- common.pytorch.run_cstorch_flow.run_cstorch_train(params, model_fn, input_fn, cs_config)[source]#
Runs the training workflow built using the cstorch API
- Parameters
params – the params dictionary extracted from the params.yaml used
model_fn – A callable that takes in the params dictionary and returns a torch.nn.Module
input_fn – A callable that takes in the params dictionary and returns a torch.utils.data.DataLoader
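A hedged end-to-end sketch of wiring the three callables into run_cstorch_flow; the params keys shown are illustrative only and the import path is an assumption.
import torch
from common.pytorch.run_cstorch_flow import run_cstorch_flow

def model_fn(params):
    return torch.nn.Linear(params["model"]["in_dim"], params["model"]["out_dim"])

def train_data_fn(params):
    ds = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 4))
    return torch.utils.data.DataLoader(ds, batch_size=params["train_input"]["batch_size"])

def eval_data_fn(params):
    ds = torch.utils.data.TensorDataset(torch.randn(16, 16), torch.randn(16, 4))
    return torch.utils.data.DataLoader(ds, batch_size=params["eval_input"]["batch_size"])

params = {  # illustrative keys only
    "model": {"in_dim": 16, "out_dim": 4},
    "train_input": {"batch_size": 8},
    "eval_input": {"batch_size": 8},
    "runconfig": {"mode": "train"},
}
run_cstorch_flow(params, model_fn, train_data_fn, eval_data_fn)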
common.pytorch.run_utils module#
Utilities for running Cerebras Pytorch Models
- common.pytorch.run_utils.arg_filter(arg: str, keyword: str) bool [source]#
Checks if a given arg matches the given keyword
- common.pytorch.run_utils.main(params: Dict[str, Any], model_fn: Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, eval_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, script: Optional[str] = None, extra_args_parser_fn: Optional[Callable[[], List[argparse.ArgumentParser]]] = None)[source]#
Entry point to running pytorch models
- common.pytorch.run_utils.run(model_fn: Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, eval_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, default_params_fn: Optional[Callable[[dict], dict]] = None, extra_args_parser_fn: Optional[Callable[[], List[argparse.ArgumentParser]]] = None)[source]#
Backward compatible entry point to running pytorch models
- common.pytorch.run_utils.run_base_model_flow(params, model_fn, train_data_fn, eval_data_fn)[source]#
Runs PytorchBaseModel and Runner flow
- common.pytorch.run_utils.run_with_params(params: Dict[str, Any], model_fn: Callable[[dict], torch.nn.Module], train_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, eval_data_fn: Optional[Callable[[dict], torch.utils.data.DataLoader]] = None, extra_args_parser_fn: Optional[Callable[[], List[argparse.ArgumentParser]]] = None)[source]#
Runs a full end-to-end CS/non-CS workflow for a given model
- Parameters
model_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a torch.nn.Module
train_data_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a PyTorch dataloader corresponding to the training dataset
eval_data_fn – A callable that takes in a ‘params’ argument which it uses to configure and return a PyTorch dataloader corresponding to the evaluation dataset
default_params_fn – An optional callable that takes in the params dictionary and updates any missing params with default values
extra_args_parser_fn – An optional callable that adds any extra parser args not covered in get_parser fn.
- common.pytorch.run_utils.sideband_eval_all(filename: str, arguments: List[str], params: Dict[Any, Any])[source]#
Temporary support for running eval multiple times via subprocess
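A hypothetical run.py-style entry point using run, reusing the model_fn/train_data_fn/eval_data_fn callables from the run_cstorch_flow sketch above; only the documented callable signatures are assumed.
from common.pytorch.run_utils import run

if __name__ == "__main__":
    run(
        model_fn=model_fn,            # dict -> torch.nn.Module
        train_data_fn=train_data_fn,  # dict -> torch.utils.data.DataLoader
        eval_data_fn=eval_data_fn,    # dict -> torch.utils.data.DataLoader
    )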
common.pytorch.utils module#
General purpose Pytorch Utilities
- class common.pytorch.utils.BufferedShuffleDataset[source]#
Bases:
torch.utils.data.IterableDataset
Dataset shuffled from the original dataset.
This class is useful to shuffle an existing instance of an IterableDataset. A buffer with buffer_size is filled with items from the dataset first. Then, each item is yielded from the buffer by reservoir sampling via the iterator. buffer_size is required to be larger than 0. For buffer_size == 1, the dataset is not shuffled. In order to fully shuffle the whole dataset, buffer_size is required to be greater than or equal to the size of the dataset.
When used with a DataLoader, each item in the dataset will be yielded from the DataLoader iterator. The method for setting up a random seed differs based on num_workers. For single-process mode (num_workers == 0), the random seed is required to be set before the DataLoader in the main process. For multi-process mode (num_workers > 0), the random seed is set by a callable function in each worker.
- Parameters
dataset (IterableDataset) – The original IterableDataset.
buffer_size (int) – The buffer size for shuffling.
Example
>>> ds = BufferedShuffleDataset(dataset)
>>> random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
>>> ds = BufferedShuffleDataset(dataset)
>>> def init_fn(worker_id):
...     random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
- class common.pytorch.utils.IterableDatasetSampler[source]#
Bases:
torch.utils.data.IterableDataset
This sampler can be used with a multi-worker distributed dataloader. All workers on all nodes get a copy of the IterableDataset but only yield samples according to the world size and their rank.
- class common.pytorch.utils.SampleGenerator[source]#
Bases:
object
Iterator which returns multiple samples of a given input data.
Can be used in place of a PyTorch DataLoader to generate synthetic data.
- Parameters
data – The data which should be returned at each iterator step.
sample_count – The maximum number of data samples to be returned.
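A sketch of SampleGenerator as a synthetic DataLoader stand-in, using only the two constructor parameters documented above (the import path is an assumption):
import torch
from common.pytorch.utils import SampleGenerator

fake_batch = (torch.randn(8, 16), torch.randint(0, 2, (8,)))
loader = SampleGenerator(data=fake_batch, sample_count=100)  # yields the same batch 100 times

for i, (features, labels) in enumerate(loader):
    if i == 0:
        print(features.shape, labels.shape)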
- common.pytorch.utils.get_adaptive_lr_layers(model, lr_adjustment_layer_type)[source]#
- Parameters
model – Pytorch model
lr_adjustment_layer_type (list) – type of layer for which an lr scaler is provided
- Returns
list of layer names for the given lr_adjustment_layer_type
- Return type
list
- common.pytorch.utils.get_checkpoints(model_dir: str) List[str] [source]#
Gather checkpoints in a model directory
- common.pytorch.utils.get_input_dtype(to_float16: bool)[source]#
Determine input datatype based on environment
- common.pytorch.utils.monkeypatch_grad_scaler_step_if_finite()[source]#
Add torch.cuda.amp.GradScaler.step_if_finite API to match cbtorch.amp.GradScaler.
- common.pytorch.utils.named_parameters_requiring_grad(model)[source]#
Returns the named parameters that should be passed to the optimizer, i.e., those that are trainable because they require gradients.
- common.pytorch.utils.partition_params_groups_with_adjusted_lr(model, param_optimizer_grouped, lr_adjustment_layers, lr_adjustment_scalars)[source]#
Generates param_groups based on lr_adjustment_layers. Each lr adjustment layer type will have a group associated with it.
- Parameters
model – Pytorch model
param_optimizer_grouped (list) – param_groups before the split based on lr_adjustment_layers
lr_adjustment_layers (list) – list of layer types with different lr adjustment scalars
lr_adjustment_scalars (list) – lr adjustment scalars
- Returns
list of dicts of param groups
- Return type
list
- common.pytorch.utils.partition_params_groups_with_weight_decay(model, param_groups, weight_decay_rate)[source]#
- Parameters
model – Pytorch model
param_groups (list) – optimizer param_groups. Currently it will be just 1 group
weight_decay_rate (float) – value of weight decay rate from yaml
- Returns
param_groups as list of dicts, split based on the weight_decay rate
- Return type
list
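A hedged sketch of chaining the two helpers above to build optimizer parameter groups; the initial group format and the "embedding" layer-type string are assumptions, and the exact split depends on the implementation.
import torch
from common.pytorch.utils import (
    partition_params_groups_with_adjusted_lr,
    partition_params_groups_with_weight_decay,
)

model = torch.nn.Sequential(torch.nn.Embedding(100, 32), torch.nn.Linear(32, 32))
param_groups = [{"params": list(model.parameters())}]  # assumed starting format

param_groups = partition_params_groups_with_weight_decay(model, param_groups, weight_decay_rate=0.01)
param_groups = partition_params_groups_with_adjusted_lr(
    model, param_groups, lr_adjustment_layers=["embedding"], lr_adjustment_scalars=[0.1]
)
print(len(param_groups), "param groups")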
- common.pytorch.utils.setup_logging(chief_logging_level: str, streamer_logging_level: str, logging_dir: Optional[str] = None)[source]#
Configure default logging format
- common.pytorch.utils.should_apply_weight_decay(model, param_name)[source]#
- Parameters
model – Pytorch model
param_name (str) – model param name
- Returns
whether to apply weight decay for the given param_name
- Return type
bool
- common.pytorch.utils.to_tensor(value, device=None)[source]#
If the provided value is a Python int or float, it is converted into a PyTorch Tensor of type int32 or float32, respectively. Otherwise, the value is returned unchanged.
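For instance (a sketch following the conversion rules above; the import path is an assumption):
import torch
from common.pytorch.utils import to_tensor

print(to_tensor(3).dtype)      # torch.int32
print(to_tensor(0.5).dtype)    # torch.float32
print(to_tensor("unchanged"))  # non-numeric values are returned as-is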
- common.pytorch.utils.visit_structure(data_structure: Union[Any, list, tuple, dict], select_fn: Callable[[Any], bool], strict: bool = False, scope: Optional[List[str]] = None) Generator[Tuple[List[str], Any], None, None] [source]#
Recursively traverse nested structure and return the items accepted by the selector.
- Parameters
data_structure – A nested data structure to traverse recursively.
select_fn – A callable that returns true if the item passed should be selected.
strict – Strictly checks that an item in the nested structure is either a list/dict/tuple or selected by the select_fn. Otherwise, raises an error. Defaults to False.
scope – The current hierarchical scope of the data structure. Defaults to None.
- Yields
A tuples of (scope, item) for each item selected by the select_fn.
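A hedged example of visit_structure walking a nested structure and selecting only tensors, based on the signature and yield description above (the import path is an assumption):
import torch
from common.pytorch.utils import visit_structure

nested = {
    "encoder": {"weights": torch.ones(2, 2), "name": "layer0"},
    "losses": [torch.tensor(0.1), torch.tensor(0.2)],
}
for scope, item in visit_structure(nested, select_fn=lambda x: isinstance(x, torch.Tensor)):
    print("/".join(scope), tuple(item.shape))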