ReferentialGym.utils package
Subpackages
Submodules
ReferentialGym.utils.utils module
ReferentialGym.utils.utils.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)
Samples from the Gumbel-Softmax distribution and optionally discretizes.
Args:
  logits: […, num_features] unnormalized log probabilities.
  tau: non-negative scalar temperature.
  hard: if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if they were the soft samples in autograd.
  dim (int): the dimension along which softmax will be computed. Default: -1.
Returns:
  Sampled tensor of the same shape as logits, drawn from the Gumbel-Softmax distribution. If hard=True, the returned samples will be one-hot; otherwise they will be probability distributions that sum to 1 across dim.
Note
The main trick for hard is to compute y_hard - y_soft.detach() + y_soft. It achieves two things:
- it makes the output value exactly one-hot (since the y_soft value is added and then subtracted);
- it makes the gradient equal to the y_soft gradient (since y_hard and y_soft.detach() contribute no gradient).
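The sketch below is a hypothetical re-implementation written only to illustrate the documented behaviour, including the straight-through trick from the note; it is not the ReferentialGym source. It assumes PyTorch, and it assumes eps serves to guard the logarithms against log(0), since the role of eps is not documented above. The name gumbel_softmax_sketch is illustrative.

```python
import torch
import torch.nn.functional as F


def gumbel_softmax_sketch(logits, tau=1.0, hard=False, eps=1e-10, dim=-1):
    """Hypothetical illustration of Gumbel-Softmax sampling; not the library code."""
    # Gumbel(0, 1) noise: g = -log(-log(U)); eps (assumed role) avoids log(0).
    u = torch.rand_like(logits)
    gumbels = -torch.log(-torch.log(u + eps) + eps)
    # Perturb the logits, divide by the temperature, normalize along `dim`.
    y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
    if not hard:
        return y_soft
    # One-hot vector at the argmax of the soft sample (carries no gradient).
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    # Straight-through: forward value is y_hard, gradient is that of y_soft.
    return y_hard - y_soft.detach() + y_soft


logits = torch.randn(2, 5, requires_grad=True)
sample = gumbel_softmax_sketch(logits, tau=0.5, hard=True)
print(sample)  # one-hot rows that still backpropagate into `logits`
```

Drawing the hard and soft samples from the same random seed shows the point of the trick: the hard output is one-hot, yet its gradient with respect to logits matches the soft sample's gradient.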
class ReferentialGym.utils.utils.StraightThroughGumbelSoftmaxLayer(input_dim, inv_tau0=0.5)
Bases: torch.nn.modules.module.Module
forward(logits, param)
Defines the computation performed at every call. This description is inherited from torch.nn.Module, which expects every subclass to override forward.
Note
Although the recipe for the forward pass is defined within this method, callers should invoke the Module instance (e.g. layer(logits, param)) rather than calling forward directly, since the former runs the registered hooks while the latter silently ignores them.
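The note about calling the instance rather than forward can be checked with a toy module. Doubler below is a hypothetical stand-in; any nn.Module subclass, including StraightThroughGumbelSoftmaxLayer, is dispatched the same way through nn.Module.__call__. A forward hook fires for module(x) but not for module.forward(x).

```python
import torch
import torch.nn as nn


class Doubler(nn.Module):
    """Hypothetical toy module used only to illustrate hook behaviour."""

    def forward(self, x):
        return 2 * x


calls = []
module = Doubler()
# A forward hook receives (module, inputs, output) after forward() has run.
module.register_forward_hook(lambda mod, inputs, output: calls.append(output))

x = torch.ones(3)
module(x)          # goes through nn.Module.__call__: the hook fires
module.forward(x)  # bypasses __call__: the hook is silently skipped
print(len(calls))  # 1
```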
ReferentialGym.utils.utils.cardinality(data)
ReferentialGym.utils.utils.compute_levenshtein_distance(s1, s2)
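compute_levenshtein_distance has no docstring here. Assuming it returns the standard edit distance between two symbol sequences (for instance, two emitted messages), the following pure-Python sketch shows the usual two-row dynamic-programming computation. levenshtein_distance is a hypothetical name, not the library function.

```python
def levenshtein_distance(s1, s2):
    """Minimum number of single-symbol insertions, deletions and substitutions
    turning s1 into s2 (works on strings or lists of symbol indices)."""
    # Row for the empty prefix of s1: reaching s2[:j] costs j insertions.
    prev = list(range(len(s2) + 1))
    for i, a in enumerate(s1, start=1):
        # curr[j] will hold the distance between s1[:i] and s2[:j].
        curr = [i]
        for j, b in enumerate(s2, start=1):
            curr.append(min(
                prev[j] + 1,             # delete a
                curr[j - 1] + 1,         # insert b
                prev[j - 1] + (a != b),  # substitute (free when a == b)
            ))
        prev = curr
    return prev[-1]


print(levenshtein_distance("kitten", "sitting"))  # 3
```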
ReferentialGym.utils.utils.compute_cosine_sim(v1, v2)
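Likewise, assuming compute_cosine_sim returns the cosine similarity of two equal-length vectors, a dependency-free sketch follows. cosine_sim is a hypothetical name; the library function may well operate on tensors or arrays rather than plain lists.

```python
import math


def cosine_sim(v1, v2):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    if len(v1) != len(v2):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    # Raises ZeroDivisionError for a zero vector, whose direction is undefined.
    return dot / (norm1 * norm2)


print(cosine_sim([3, 4], [4, 3]))  # 0.96
```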
ReferentialGym.utils.utils.compute_levenshtein_distance_for_idx_over_comprange(sentences, idx, comprange)
ReferentialGym.utils.utils.compute_cosine_sim_for_idx_over_comprange(features, idx, comprange)
ReferentialGym.utils.utils.compute_topographic_similarity_parallel(sentences, features, comprange=100, max_workers=32)
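Topographic similarity is commonly defined as the Spearman rank correlation between pairwise distances in message space and pairwise distances in feature (meaning) space. The sketch below illustrates that definition under stated assumptions and is not the library's implementation: message distance is Levenshtein distance, feature distance is 1 - cosine similarity, each index idx is compared only with the next comprange items (a guess at the role of comprange), and max_workers sizes a thread pool. All function names are hypothetical.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def levenshtein(s1, s2):
    prev = list(range(len(s2) + 1))
    for i, a in enumerate(s1, start=1):
        curr = [i]
        for j, b in enumerate(s2, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (a != b)))
        prev = curr
    return prev[-1]


def cosine_distance(v1, v2):
    dot = sum(a * b for a, b in zip(v1, v2))
    norms = math.sqrt(sum(a * a for a in v1)) * math.sqrt(sum(b * b for b in v2))
    return 1.0 - dot / norms


def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda k: values[k])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for pos in range(start, end + 1):
            ranks[order[pos]] = (start + end) / 2 + 1
        start = end + 1
    return ranks


def spearman(xs, ys):
    """Pearson correlation of the ranks (undefined if either input is constant)."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = math.sqrt(sum((a - mx) ** 2 for a in rx))
    sy = math.sqrt(sum((b - my) ** 2 for b in ry))
    return cov / (sx * sy)


def pairs_for_idx(sentences, features, idx, comprange):
    """(message distance, feature distance) between idx and the next comprange items."""
    stop = min(len(sentences), idx + 1 + comprange)
    return [(levenshtein(sentences[idx], sentences[j]),
             cosine_distance(features[idx], features[j]))
            for j in range(idx + 1, stop)]


def topographic_similarity(sentences, features, comprange=100, max_workers=32):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        chunks = pool.map(lambda i: pairs_for_idx(sentences, features, i, comprange),
                          range(len(sentences)))
        pairs = [p for chunk in chunks for p in chunk]
    return spearman([p[0] for p in pairs], [p[1] for p in pairs])


# Edit distances and feature distances order the three pairs identically.
print(topographic_similarity(["aaa", "aab", "bbb"], [(1.0, 0.0), (3.0, 1.0), (0.0, 1.0)]))
# approximately 1.0
```

Threads keep the example simple, but CPU-bound pure-Python work is serialized by the GIL; a concurrent.futures.ProcessPoolExecutor would give real parallelism at the cost of pickling the inputs for each worker.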
ReferentialGym.utils.utils.query_vae_latent_space(omodel, sample, path, test=False, full=True, idxoffset=None, suffix='', use_cuda=False)