ReferentialGym.networks package

Submodules

ReferentialGym.networks.autoregressive_networks module

class ReferentialGym.networks.autoregressive_networks.Distribution

Bases: object

sample()
log_prob(values)
class ReferentialGym.networks.autoregressive_networks.Bernoulli(probs)

Bases: ReferentialGym.networks.autoregressive_networks.Distribution

sample()
log_prob(values)
class ReferentialGym.networks.autoregressive_networks.Normal(mean, std)

Bases: ReferentialGym.networks.autoregressive_networks.Distribution

sample()
log_prob(value)
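The `Distribution` interface only lists `sample()` and `log_prob()`. To show what those two methods compute, here is a minimal pure-Python sketch of an independent-dimension Normal and a Bernoulli over lists of floats. The class names `NormalSketch` and `BernoulliSketch` and the `EPS` clamp are hypothetical. The ReferentialGym classes presumably operate on torch tensors instead.

```python
import math
import random
from typing import List

EPS = 1e-12  # hypothetical clamp: keeps log() finite when a probability is exactly 0 or 1


class NormalSketch:
    """Hypothetical stand-in for Normal(mean, std) with independent dimensions."""

    def __init__(self, mean: List[float], std: List[float]):
        self.mean = mean
        self.std = std

    def sample(self) -> List[float]:
        # One Gaussian draw per dimension.
        return [random.gauss(m, s) for m, s in zip(self.mean, self.std)]

    def log_prob(self, value: List[float]) -> List[float]:
        # Element-wise log N(v; m, s^2) = -(v - m)^2 / (2 s^2) - log(s) - 0.5 log(2 pi)
        return [
            -((v - m) ** 2) / (2 * s ** 2) - math.log(s) - 0.5 * math.log(2 * math.pi)
            for v, m, s in zip(value, self.mean, self.std)
        ]


class BernoulliSketch:
    """Hypothetical stand-in for Bernoulli(probs)."""

    def __init__(self, probs: List[float]):
        self.probs = probs

    def sample(self) -> List[float]:
        # Draw 1.0 with probability p, otherwise 0.0.
        return [1.0 if random.random() < p else 0.0 for p in self.probs]

    def log_prob(self, values: List[float]) -> List[float]:
        # log p(v) = v log(p) + (1 - v) log(1 - p)
        return [
            v * math.log(max(p, EPS)) + (1 - v) * math.log(max(1 - p, EPS))
            for v, p in zip(values, self.probs)
        ]
```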
class ReferentialGym.networks.autoregressive_networks.ResNetEncoder(input_shape, latent_dim=32, pretrained=False, nbr_layer=4, use_coordconv=False)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18

get_feature_shape()
encode(x)
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class ReferentialGym.networks.autoregressive_networks.ResNetAvgPooledEncoder(input_shape, latent_dim=32, pretrained=False, nbr_layer=4, use_coordconv=False)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled

get_feature_shape()
encode(x)
forward(x)

class ReferentialGym.networks.autoregressive_networks.ResNetParallelAttentionEncoder(input_shape, latent_dim=10, nbr_attention_slot=10, pretrained=False, nbr_layer=4, use_coordconv=False)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18

get_feature_shape()
encode(x)
forward(x)

class ReferentialGym.networks.autoregressive_networks.addXYSfeatures(nbr_attention_slot=10)

Bases: torch.nn.modules.module.Module

forward(x)

class ReferentialGym.networks.autoregressive_networks.ResNetPHDPAEncoder(input_shape, latent_dim=10, nbr_attention_slot=10, pretrained=False, nbr_layer=4, use_coordconv=False)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18

get_feature_shape()
encode(x)
forward(x)

class ReferentialGym.networks.autoregressive_networks.Decoder(output_shape=[3, 64, 64], net_depth=3, latent_dim=32, conv_dim=64)

Bases: torch.nn.modules.module.Module

decode(z)
forward(z)

class ReferentialGym.networks.autoregressive_networks.BroadcastingDecoder(output_shape=[3, 64, 64], net_depth=3, kernel_size=3, stride=1, padding=1, latent_dim=32, conv_dim=64)

Bases: torch.nn.modules.module.Module

decode(z)
forward(z)

class ReferentialGym.networks.autoregressive_networks.BroadcastingDeconvDecoder(output_shape=[3, 64, 64], net_depth=3, latent_dim=32, conv_dim=64)

Bases: torch.nn.modules.module.Module

decode(z)
forward(z)

class ReferentialGym.networks.autoregressive_networks.ParallelAttentionBroadcastingDeconvDecoder(output_shape=[3, 64, 64], net_depth=3, latent_dim=32, nbr_attention_slot=10, conv_dim=64)

Bases: torch.nn.modules.module.Module

decode(z)
forward(z)

class ReferentialGym.networks.autoregressive_networks.TotalCorrelationDiscriminator(VAE)

Bases: object

update(z, train=True)
step()
permutate_latents(z)
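The name `permutate_latents` and the `factor_vae_gamma` argument of `BetaVAE` suggest the FactorVAE recipe (Kim and Mnih, 2018). In that recipe, a discriminator learns to tell real latent codes apart from codes whose dimensions have been shuffled independently across the batch, which approximates sampling from the product of the marginals. The sketch below shows that permutation on a batch stored as a list of rows. It assumes `permutate_latents` does the same thing, and the function name `permute_latent_dims` is hypothetical.

```python
import random
from typing import List, Optional


def permute_latent_dims(z: List[List[float]], rng: Optional[random.Random] = None) -> List[List[float]]:
    """Shuffle each latent dimension independently across the batch.

    z holds a batch of latent codes with shape (batch_size, latent_dim).
    The input is left unmodified; a permuted copy is returned.
    """
    rng = rng or random.Random()
    batch_size = len(z)
    latent_dim = len(z[0]) if batch_size else 0
    permuted = [row[:] for row in z]
    for d in range(latent_dim):
        # A fresh permutation of the batch indices for this dimension only.
        order = list(range(batch_size))
        rng.shuffle(order)
        for i, j in enumerate(order):
            permuted[i][d] = z[j][d]
    return permuted
```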
class ReferentialGym.networks.autoregressive_networks.BetaVAE(beta=10000.0, encoder=None, decoder=None, latent_dim=32, nbr_attention_slot=None, input_shape=[3, 64, 64], NormalOutputDistribution=True, maxEncodingCapacity=1000, nbrEpochTillMaxEncodingCapacity=4, constrainedEncoding=True, observation_sigma=0.05, factor_vae_gamma=0.0)

Bases: torch.nn.modules.module.Module

get_feature_shape()
_compute_feature_shape(input_dim=None, nbr_layer=None)
reparameterize(mu, log_var)
forward(x)

encode(x)
encodeZ(x)
decode(z)
_forward(x=None, evaluation=False, fixed_latent=None, data=None)
get_feat_map()
compute_loss(x=None, fixed_latent=None, data=None, evaluation=False, observation_sigma=None)
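`reparameterize(mu, log_var)` presumably implements the standard reparameterization trick, z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, 1). The `maxEncodingCapacity`, `nbrEpochTillMaxEncodingCapacity` and `constrainedEncoding` arguments match the capacity-annealed objective of Burgess et al. (2018), beta * |KL - C|, where the target capacity C grows linearly up to a maximum. The pure-Python sketch below shows both under those assumptions. The helper names and the linear per-epoch schedule are hypothetical, and this section does not document how ReferentialGym actually schedules C.

```python
import math
import random
from typing import List, Optional


def reparameterize(mu: List[float], log_var: List[float], rng: Optional[random.Random] = None) -> List[float]:
    # z = mu + sigma * eps, with sigma = exp(0.5 * log_var) and eps ~ N(0, 1).
    rng = rng or random.Random()
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: List[float], log_var: List[float]) -> float:
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    return sum(-0.5 * (1 + lv - m ** 2 - math.exp(lv)) for m, lv in zip(mu, log_var))


def capacity_at(epoch: float, max_capacity: float, epochs_till_max: float) -> float:
    # Hypothetical linear schedule: C rises from 0 to max_capacity, then stays there.
    return min(max_capacity, max_capacity * epoch / epochs_till_max)


def constrained_kl_term(mu, log_var, beta, epoch, max_capacity, epochs_till_max) -> float:
    # beta * |KL - C|: penalize deviating from the current target capacity C.
    c = capacity_at(epoch, max_capacity, epochs_till_max)
    return beta * abs(kl_to_standard_normal(mu, log_var) - c)
```

Taking the absolute difference to C, rather than penalizing the KL directly, lets the encoder use up to C nats of information without penalty as training progresses.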
class ReferentialGym.networks.autoregressive_networks.UNetBlock(in_channel, out_channel, upsample=True, interpolate=False, interpolation_factor=2, batch_norm=False)

Bases: torch.nn.modules.module.Module

forward(x)

class ReferentialGym.networks.autoregressive_networks.UNet(input_shape, in_channel, out_channel, basis_nbr_channel=32, block_depth=3, batch_norm=False)

Bases: torch.nn.modules.module.Module

forward(x)

class ReferentialGym.networks.autoregressive_networks.AttentionNetwork(input_shape, in_channel, attention_basis_nbr_channel=32, attention_block_depth=3)

Bases: torch.nn.modules.module.Module

in_channel = None

self.unet = UNet(input_shape=self.input_shape, in_channel=self.in_channel, out_channel=1, basis_nbr_channel=attention_basis_nbr_channel, block_depth=attention_block_depth)

forward(x, logscope)

class ReferentialGym.networks.autoregressive_networks.ParallelAttentionNetwork(input_shape, in_channel, nbr_attention_slot=10, attention_basis_nbr_channel=32, attention_block_depth=3)

Bases: torch.nn.modules.module.Module

forward(x, logscope)

class ReferentialGym.networks.autoregressive_networks.MONet(gamma=0.5, input_shape=[3, 64, 64], nbr_attention_slot=10, anet_basis_nbr_channel=32, anet_block_depth=3, cvae_beta=0.5, cvae_latent_dim=10, cvae_decoder_conv_dim=32, cvae_pretrained=False, cvae_resnet_encoder=False, cvae_resnet_nbr_layer=2, cvae_decoder_nbr_layer=3, cvae_EncodingCapacityStep=None, cvae_maxEncodingCapacity=100, cvae_nbrEpochTillMaxEncodingCapacity=4, cvae_constrainedEncoding=True, cvae_observation_sigma=0.05, compactness_factor=None)

Bases: ReferentialGym.networks.autoregressive_networks.BetaVAE

get_feature_shape()
encodeZ(x)
decode(z)
forward(x, observation_sigma=None, compute_loss=False)

compute_loss(x=None, observation_sigma=None)
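`AttentionNetwork.forward(x, logscope)` and the `MONet` classes presumably follow the recursive attention scheme of MONet (Burgess et al., 2019). Each step takes the remaining scope s_{k-1} and an attention probability alpha_k, then produces a mask m_k = s_{k-1} * alpha_k and a new scope s_k = s_{k-1} * (1 - alpha_k). The last slot takes whatever scope remains, so the masks sum to 1. The `logscope` argument suggests the computation is done in log space, which avoids underflow. The sketch below computes the masks for a single pixel from hypothetical per-step probabilities `alphas`. In the real network, each alpha would be a per-pixel map predicted by the U-Net, and `monet_log_masks` is a hypothetical name.

```python
import math
from typing import List


def monet_log_masks(alphas: List[float]) -> List[float]:
    """Return log masks for len(alphas) + 1 slots at a single pixel.

    Each alphas[k] is an attention probability strictly between 0 and 1.
    """
    log_scope = 0.0  # log s_0 = log 1: the whole pixel is still unexplained
    log_masks = []
    for alpha in alphas:
        # log m_k = log s_{k-1} + log alpha_k
        log_masks.append(log_scope + math.log(alpha))
        # log s_k = log s_{k-1} + log(1 - alpha_k)
        log_scope += math.log1p(-alpha)
    # The final slot receives all of the remaining scope.
    log_masks.append(log_scope)
    return log_masks
```

With `nbr_attention_slot=K`, this scheme needs K - 1 attention steps.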
class ReferentialGym.networks.autoregressive_networks.ParallelMONet(gamma=0.5, input_shape=[3, 64, 64], nbr_attention_slot=10, anet_basis_nbr_channel=32, anet_block_depth=3, cvae_beta=0.5, cvae_latent_dim=10, cvae_decoder_conv_dim=32, cvae_pretrained=False, cvae_resnet_encoder=False, cvae_resnet_nbr_layer=2, cvae_decoder_nbr_layer=3, cvae_EncodingCapacityStep=None, cvae_maxEncodingCapacity=100, cvae_nbrEpochTillMaxEncodingCapacity=4, cvae_constrainedEncoding=True, cvae_observation_sigma=0.05, compactness_factor=None)

Bases: ReferentialGym.networks.autoregressive_networks.BetaVAE

get_feature_shape()
encodeZ(x)
decode(z)
forward(x, observation_sigma=None, compute_loss=False)

compute_loss(x=None, observation_sigma=None)

ReferentialGym.networks.homoscedastic_multitask_loss module

class ReferentialGym.networks.homoscedastic_multitask_loss.HomoscedasticMultiTasksLoss(nbr_tasks=2, use_cuda=False)

Bases: torch.nn.modules.module.Module

forward(loss_dict)
Parameters

loss_dict – Dict[str, Tuple[float, torch.Tensor]] mapping each loss name to a (linear coefficient, loss) pair, where each loss tensor has batched shape (batch_size, 1).
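Homoscedastic multi-task weighting usually follows Kendall et al. (2018). A common simplified form gives each task i a learned log-variance s_i = log(sigma_i^2) and minimizes sum_i exp(-s_i) * L_i + s_i. Under this form, tasks with high estimated noise are down-weighted, and the + s_i term stops the variances from growing without bound. The sketch below applies that formula to a `loss_dict` shaped as described above, with plain floats in place of tensors and batch-averaged losses. Several details are assumptions: whether `HomoscedasticMultiTasksLoss` uses this exact form (the paper's regression form includes a factor 1/2), how it applies the linear coefficient, and how it reduces over the batch. The function name is hypothetical.

```python
import math
from typing import Dict, List, Tuple


def homoscedastic_total_loss(
    loss_dict: Dict[str, Tuple[float, List[float]]],
    log_vars: Dict[str, float],
) -> float:
    """Combine per-task losses with learned homoscedastic uncertainty.

    loss_dict maps a task name to (linear coefficient, per-sample losses).
    log_vars maps the same task name to s = log(sigma^2), a learnable scalar.
    """
    total = 0.0
    for name, (coeff, batch_losses) in loss_dict.items():
        mean_loss = coeff * sum(batch_losses) / len(batch_losses)
        s = log_vars[name]
        # exp(-s) down-weights noisy tasks; +s penalizes inflating the variance.
        total += math.exp(-s) * mean_loss + s
    return total
```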

ReferentialGym.networks.networks module

ReferentialGym.networks.networks.retrieve_output_shape(input, model)
ReferentialGym.networks.networks.hasnan(tensor)
ReferentialGym.networks.networks.handle_nan(layer, verbose=True)
ReferentialGym.networks.networks.layer_init(layer, w_scale=1.0)
class ReferentialGym.networks.networks.addXYfeatures

Bases: torch.nn.modules.module.Module

forward(x, outputFsizes=False)

class ReferentialGym.networks.networks.addXYRhoThetaFeatures

Bases: torch.nn.modules.module.Module

forward(x, outputFsizes=False)

ReferentialGym.networks.networks.conv(sin, sout, k, stride=1, padding=0, batchNorm=True)
ReferentialGym.networks.networks.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)

3x3 convolution with padding

ReferentialGym.networks.networks.conv1x1(in_planes, out_planes, stride=1)

1x1 convolution
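In the `conv3x3` description, "with padding" presumably means the same thing as in torchvision's helper of the same name: `padding` is set equal to `dilation`, which keeps the spatial size unchanged at stride 1. The standard output-size formula from the PyTorch `Conv2d` documentation makes this easy to check. `conv_out_size` is a hypothetical helper and is not part of ReferentialGym.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # Output length along one spatial axis: floor((n + 2p - d(k - 1) - 1) / s) + 1
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


for dilation in (1, 2, 3):
    # conv3x3-style layer: kernel 3, padding == dilation, stride 1 keeps 64 -> 64.
    assert conv_out_size(64, 3, stride=1, padding=dilation, dilation=dilation) == 64
```

A 1x1 convolution needs no padding to preserve size, and a stride of 2 halves it.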

ReferentialGym.networks.networks.deconv(sin, sout, k, stride=1, padding=0, batchNorm=True)
ReferentialGym.networks.networks.coordconv(sin, sout, kernel_size, stride=1, padding=0, batchNorm=False, bias=True, groups=1, dilation=1)
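`addXYfeatures` and the `coordconv*` helpers appear to implement CoordConv (Liu et al., 2018). CoordConv concatenates two extra channels holding each pixel's x and y coordinates before convolving, so that filters can condition on absolute position. The sketch below shows that concatenation for a single C x H x W image stored as nested lists. The [-1, 1] normalization, the x-then-y channel order and the name `add_xy_channels` are assumptions.

```python
from typing import List

Image = List[List[List[float]]]  # (channels, height, width)


def add_xy_channels(image: Image) -> Image:
    """Return a copy of image with an x and a y coordinate channel appended."""
    height = len(image[0])
    width = len(image[0][0])

    def norm(i: int, n: int) -> float:
        # Map index 0..n-1 linearly onto [-1, 1]; a single-pixel axis maps to 0.
        return 0.0 if n == 1 else -1.0 + 2.0 * i / (n - 1)

    x_channel = [[norm(col, width) for col in range(width)] for _ in range(height)]
    y_channel = [[norm(row, height) for _ in range(width)] for row in range(height)]
    return [[r[:] for r in ch] for ch in image] + [x_channel, y_channel]
```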
ReferentialGym.networks.networks.coordconv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)

3x3 coord convolution with padding

ReferentialGym.networks.networks.coordconv1x1(in_planes, out_planes, stride=1)

1x1 coord convolution

ReferentialGym.networks.networks.coorddeconv(sin, sout, kernel_size, stride=2, padding=1, batchNorm=True, bias=False)
ReferentialGym.networks.networks.coord4conv(sin, sout, kernel_size, stride=1, padding=0, batchNorm=False, bias=True, groups=1, dilation=1)
ReferentialGym.networks.networks.coord4conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)

3x3 coord convolution with padding

ReferentialGym.networks.networks.coord4conv1x1(in_planes, out_planes, stride=1)

1x1 coord convolution

ReferentialGym.networks.networks.coord4deconv(sin, sout, kernel_size, stride=2, padding=1, batchNorm=True, bias=False)
class ReferentialGym.networks.networks.FCBody(state_dim, hidden_units=(64, 64), gate=<function relu>)

Bases: torch.nn.modules.module.Module

forward(x)

get_feature_shape()
class ReferentialGym.networks.networks.ConvolutionalBody(input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.LeakyReLU'>], use_coordconv=None)

Bases: torch.nn.modules.module.Module

_compute_feat_map(x)
get_feat_map()
forward(x, non_lin_output=True)

get_input_shape()
get_feature_shape()
_compute_feature_shape(input_dim=None, nbr_layer=None)
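`ConvolutionalBody` takes parallel lists `channels`, `kernel_sizes`, `strides` and `paddings`. The defaults pair `channels=[3, 3]` with single-element kernel, stride and padding lists. This suggests that `channels[0]` is the input depth and `channels[i + 1]` is the output depth of layer i. Under that assumption, the final feature-map shape that `_compute_feature_shape` has to produce can be sketched as follows. The helper is hypothetical and assumes dilation 1 and no pooling between layers.

```python
from typing import List, Tuple


def conv_stack_output_shape(
    input_shape: Tuple[int, int, int],
    channels: List[int],
    kernel_sizes: List[int],
    strides: List[int],
    paddings: List[int],
) -> Tuple[int, int, int]:
    """Return (C, H, W) after a stack of 2D convolutions."""
    if len(channels) != len(kernel_sizes) + 1:
        raise ValueError("expected one more channel entry than convolutional layers")
    _, h, w = input_shape
    for k, s, p in zip(kernel_sizes, strides, paddings):
        # Standard output size with dilation 1: floor((n + 2p - k) / s) + 1
        h = (h + 2 * p - k) // s + 1
        w = (w + 2 * p - k) // s + 1
    return channels[-1], h, w
```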
class ReferentialGym.networks.networks.EntityPrioredConvolutionalBody(input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.LeakyReLU'>], use_coordconv=None)

Bases: ReferentialGym.networks.networks.ConvolutionalBody

_compute_feat_map(x)
class ReferentialGym.networks.networks.ConvolutionalLstmBody(input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, rnn_hidden_units=(256,), dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.ReLU'>], gate=<function relu>, use_coordconv=None)

Bases: ReferentialGym.networks.networks.ConvolutionalBody

forward(inputs)
Parameters

inputs – input to the LSTM cells, structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).

hidden_states: list of hidden states, one per layer in self.layers. cell_states: list of cell states, one per layer in self.layers.

get_reset_states(cuda=False, repeat=1)
get_input_shape()
get_feature_shape()
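The `forward(inputs)` entry describes the recurrent state as a dict with one entry per layer under `hidden` and `cell`. The sketch below builds such a zero-initialized state for a tuple like `rnn_hidden_units=(256,)`, in the spirit of `get_reset_states(cuda=False, repeat=1)`. It assumes that `repeat` plays the role of a batch size. The helper names and the use of nested lists instead of tensors are for illustration only.

```python
from typing import Dict, List, Sequence, Tuple

State = Dict[str, List[List[List[float]]]]


def make_reset_states(rnn_hidden_units: Sequence[int], repeat: int = 1) -> State:
    """Zero hidden and cell states: one (repeat, units) block per recurrent layer."""
    def zeros(units: int) -> List[List[float]]:
        # Separate inner lists, so mutating one entry does not alias the others.
        return [[0.0] * units for _ in range(repeat)]

    return {
        "hidden": [zeros(units) for units in rnn_hidden_units],
        "cell": [zeros(units) for units in rnn_hidden_units],
    }


def make_inputs(feed_forward_input: object, states: State) -> Tuple[object, State]:
    # The documented structure for forward(inputs): (feed_forward_input, {hidden, cell}).
    return (feed_forward_input, states)
```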
class ReferentialGym.networks.networks.ConvolutionalGruBody(input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, rnn_hidden_units=(256,), dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.ReLU'>], gate=<function relu>, use_coordconv=None)

Bases: ReferentialGym.networks.networks.ConvolutionalBody

forward(inputs)
Parameters

inputs – input to the GRU cells, structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).

hidden_states: list of hidden states, one per layer in self.layers. cell_states: list of cell states, one per layer in self.layers.

get_reset_states(cuda=False, repeat=1)
get_input_shape()
get_feature_shape()
class ReferentialGym.networks.networks.LSTMBody(state_dim, rnn_hidden_units=256, gate=<function relu>)

Bases: torch.nn.modules.module.Module

forward(inputs)
Parameters

inputs – input to the LSTM cells, structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).

hidden_states: list of hidden states, one per layer in self.layers. cell_states: list of cell states, one per layer in self.layers.

get_reset_states(cuda=False, repeat=1)
get_feature_shape()
class ReferentialGym.networks.networks.GRUBody(state_dim, rnn_hidden_units=256, gate=<function relu>)

Bases: torch.nn.modules.module.Module

forward(inputs)
Parameters

inputs – input to the GRU cells, structured as (feed_forward_input, {hidden: hidden_states, cell: cell_states}).

hidden_states: list of hidden states, one per layer in self.layers. cell_states: list of cell states, one per layer in self.layers.

get_reset_states(cuda=False, repeat=1)
get_feature_shape()
class ReferentialGym.networks.networks.MHDPA(depth_dim=37, interactions_dim=64, hidden_size=256)

Bases: torch.nn.modules.module.Module

forward(x, usef=False)

save(path)
load(path)
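MHDPA presumably stands for multi-head dot-product attention. Each head computes scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, over the set of entities. The single-head, pure-Python sketch below omits the learned query/key/value projections and any MLP that `MHDPA` may apply afterwards. `scaled_dot_product_attention` is a hypothetical helper.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """q: (n, d), k: (m, d), v: (m, d_v) -> (n, d_v)."""
    d = len(q[0])
    out = []
    for q_row in q:
        # Similarity of this query to every key, scaled by sqrt(d).
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d) for k_row in k]
        weights = softmax(scores)
        # Output row: attention-weighted sum of the value rows.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out
```

When all keys are identical, the weights are uniform and each output row is simply the mean of the value rows.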
class ReferentialGym.networks.networks.MHDPA_RN(depth_dim=35, nbrHead=3, nbrRecurrentSharedLayers=1, nbrEntity=7, units_per_MLP_layer=256, interactions_dim=128, output_dim=None, dropout_prob=0.0, use_coord4=False)

Bases: torch.nn.modules.module.Module

forwardScaledDPAhead(x, head, reset_hidden_states=False)
forwardStackedMHDPA(augx)
forward(x=None, augx=None)

class ReferentialGym.networks.networks.ConvolutionalMHDPABody(input_shape, feature_dim=256, channels=[3, 3], kernel_sizes=[1], strides=[1], paddings=[0], fc_hidden_units=None, dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.LeakyReLU'>], use_coordconv=None, nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128, use_coord4=False)

Bases: ReferentialGym.networks.networks.ConvolutionalBody

forward(x)

class ReferentialGym.networks.networks.VGG(features, num_classes=1000, init_weights=True)

Bases: torch.nn.modules.module.Module

Adapts the VGG architecture so that it can be selected as a convolutional backbone without its classification layers.

forward(x)

_initialize_weights()
ReferentialGym.networks.networks._vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs)
class ReferentialGym.networks.networks.ModelVGG16(input_shape, feature_dim=512, pretrained=True, final_layer_idx=None)

Bases: torch.nn.modules.module.Module

forward(x)

get_feature_shape()
class ReferentialGym.networks.networks.ExtractorVGG16(input_shape, final_layer_idx=None, pretrained=True)

Bases: torch.nn.modules.module.Module

forward(x)

ReferentialGym.networks.residual_networks module

class ReferentialGym.networks.residual_networks.ResNet(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None)

Bases: torch.nn.modules.module.Module

_make_layer(block, planes, blocks, stride=1, dilate=False)
_forward_impl(x)
forward(x)

class ReferentialGym.networks.residual_networks.CoordResNet(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None)

Bases: torch.nn.modules.module.Module

_make_layer(block, planes, blocks, stride=1, dilate=False)
_forward_impl(x)
forward(x)

class ReferentialGym.networks.residual_networks.ModelResNet18(input_shape, feature_dim=256, nbr_layer=None, pretrained=False, use_coordconv=False)

Bases: torchvision.models.resnet.ResNet

_compute_feature_shape(input_dim, nbr_layer)
_compute_feat_map(x)
get_feat_map()
_compute_features(features_map)
forward(x)

get_feature_shape()
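`ModelResNet18` accepts `nbr_layer`, which presumably truncates the network after that many residual stages. The sketch below computes the resulting feature-map shape. It assumes the standard torchvision ResNet-18 layout: a 7x7 stride-2 convolution and a 3x3 stride-2 max pool as the stem, then stages of 64, 128, 256 and 512 channels in which every stage after the first halves the resolution. The helper is hypothetical and does not reproduce ReferentialGym's own `_compute_feature_shape`, nor any head that maps the features to `feature_dim`.

```python
from typing import Tuple


def _out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Standard output size with dilation 1.
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_shape(input_size: int, nbr_layer: int = 4) -> Tuple[int, int, int]:
    """(C, H, W) after the stem and the first nbr_layer residual stages, for a square input."""
    if not 1 <= nbr_layer <= 4:
        raise ValueError("ResNet-18 has four residual stages")
    size = _out(input_size, kernel=7, stride=2, padding=3)  # conv1
    size = _out(size, kernel=3, stride=2, padding=1)  # maxpool
    channels = 64  # layer1 keeps resolution and width
    for _ in range(nbr_layer - 1):
        # layer2..layer4 each open with a stride-2 3x3 convolution and double the channels.
        size = _out(size, kernel=3, stride=2, padding=1)
        channels *= 2
    return channels, size, size
```

For the default `input_shape=[3, 64, 64]` used elsewhere in this package, the full four-stage network would leave only a 2x2 map. This may explain why smaller `nbr_layer` values such as `cvae_resnet_nbr_layer=2` appear as defaults.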
class ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled(input_shape, feature_dim=256, nbr_layer=None, pretrained=False, detach_conv_maps=False, use_coordconv=False)

Bases: torchvision.models.resnet.ResNet

_compute_feature_shape(input_dim=None, nbr_layer=None)
_compute_feat_map(x)
_compute_features(features_map)
forward(x)

get_feat_map()
get_feature_shape()
class ReferentialGym.networks.residual_networks.ResNet18MHDPA(input_shape, feature_dim=256, nbr_layer=None, pretrained=False, use_coordconv=False, dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.LeakyReLU'>], nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18

forward(x)

class ReferentialGym.networks.residual_networks.ResNet18AvgPooledMHDPA(input_shape, feature_dim=256, nbr_layer=None, pretrained=False, detach_conv_maps=False, use_coordconv=False, dropout=0.0, non_linearities=[<class 'torch.nn.modules.activation.LeakyReLU'>], nbrHead=4, nbrRecurrentSharedLayers=1, units_per_MLP_layer=512, interaction_dim=128)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18AvgPooled

_compute_feature_shape(input_dim=None, nbr_layer=None)
forward(x)

class ReferentialGym.networks.residual_networks.ExtractorResNet18(input_shape, final_layer_idx=None, pretrained=True)

Bases: ReferentialGym.networks.residual_networks.ModelResNet18

forward(x)

Module contents

ReferentialGym.networks.choose_architecture(architecture, kwargs=None, fc_hidden_units_list=None, rnn_hidden_units_list=None, input_shape=None, feature_dim=None, nbr_channels_list=None, kernels=None, strides=None, paddings=None, dropout=0.0, MHDPANbrHead=4, MHDPANbrRecUpdate=1, MHDPANbrMLPUnit=512, MHDPAInteractionDim=128)