prompting.validators.gating#

Module Contents#

Classes#

BaseGatingModel

This class is an abstract base class for the gating model. It defines the interface that concrete gating models must implement.

GatingModel

This class is a PyTorch module that encapsulates the gating model functionality.

SentenceEmbedGatingModel

This class is a PyTorch module that encapsulates a custom version of a gating model based on sentence transformers.

class prompting.validators.gating.BaseGatingModel#

Bases: torch.nn.Module, abc.ABC

This class is an abstract base class for the gating model. It defines the interface that concrete gating models must implement.

classmethod add_args(parser)#

Adds command line arguments to the parser that are used to configure the gating model. The arguments added are:

  • --gating.model_name: Name of the pre-trained transformer-based language model to use as the encoding layer for the gating model. (default: ‘EleutherAI/gpt-neo-125m’)

  • --gating.num_uids: Number of uids to gate on. (default: 4096)

  • --gating.learning_rate: Learning rate for the gating model optimizer. (default: 0.01)

  • --gating.momentum: Momentum for the gating model optimizer. (default: 0.9)

Parameters:

parser (argparse.ArgumentParser) –
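A minimal usage sketch, assuming the flags above are registered on a standard argparse parser; the command line values and the dotted-attribute access are illustrative, not documented behaviour:

```python
import argparse

from prompting.validators.gating import BaseGatingModel

parser = argparse.ArgumentParser()
BaseGatingModel.add_args(parser)  # registers the --gating.* flags documented above

# placeholder command line; any of the documented flags can be overridden here
args = parser.parse_args(["--gating.num_uids", "1024"])
print(getattr(args, "gating.num_uids"))  # argparse keeps the dotted flag name as the attribute
```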

abstract forward(message)#

Forward pass through the gating model

Parameters:

message (str) –

Return type:

torch.FloatTensor

abstract backward(scores, rewards)#

Backward pass through the gating model

Parameters:
  • scores (torch.FloatTensor) –

  • rewards (torch.FloatTensor) –

abstract resync(previous_metagraph, metagraph)#

Resync the gating model with the latest state of the network.

Parameters:
  • previous_metagraph (bittensor.metagraph.Metagraph) – Previous state of the metagraph before the resync.

  • metagraph (bittensor.metagraph.Metagraph) – Latest state of the metagraph with updated uids and hotkeys.

classmethod config()#

Returns a configuration object that contains the command line arguments for the gating model.

classmethod check_config(config)#

Validates the configuration object for the gating model.

Parameters:

config (bittensor.Config) –
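As an illustration of the interface above, the following hypothetical subclass implements the three abstract methods. It is a sketch only: it assumes BaseGatingModel.__init__ requires no extra constructor arguments, and the uniform scoring is purely illustrative.

```python
import torch

from prompting.validators.gating import BaseGatingModel


class UniformGatingModel(BaseGatingModel):
    """Hypothetical gating model that scores every uid equally (illustration only)."""

    def __init__(self, num_uids: int):
        super().__init__()  # assumes the base class requires no extra constructor arguments
        self.num_uids = num_uids

    def forward(self, message: str) -> torch.FloatTensor:
        # a real model would encode `message`; here every uid gets the same score
        return torch.ones(self.num_uids)

    def backward(self, scores: torch.FloatTensor, rewards: torch.FloatTensor):
        # nothing to train in this toy model
        pass

    def resync(self, previous_metagraph, metagraph):
        # track the latest network size; the docstrings above refer to it as metagraph.n
        self.num_uids = int(metagraph.n)
```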

class prompting.validators.gating.GatingModel(metagraph, config=None, model_name=None, num_uids=None)#

Bases: BaseGatingModel

This class is a PyTorch module that encapsulates the gating model functionality.

  • The backward method runs a backward pass through the model using the mean squared error between the normalized scores and the normalized rewards as the loss function.

  • The forward method runs a forward pass through the model, encoding the input message and generating scores for each uid in the network. The scores are returned as a tensor.

Parameters:
  • metagraph (bittensor.metagraph.Metagraph) –

  • config (bittensor.config) –

  • model_name (str) –

  • num_uids (int) –
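A rough end-to-end sketch of constructing and training the model. The netuid, prompt, and random rewards below are placeholders, and the bt.metagraph call assumes the usual bittensor API for fetching network state:

```python
import bittensor as bt
import torch

from prompting.validators.gating import GatingModel

metagraph = bt.metagraph(netuid=1)      # placeholder netuid
config = GatingModel.config()           # built from the --gating.* flags documented above
GatingModel.check_config(config)

gating = GatingModel(metagraph=metagraph, config=config)

scores = gating.forward("Explain proof of stake.")  # one score per uid, shape (metagraph.n)
rewards = torch.rand_like(scores)                   # stand-in for reward-model output
gating.backward(scores, rewards)                    # MSE between normalized scores and rewards
```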

backward(scores, rewards)#

Runs a backward pass through the model.

Parameters:
  • scores (torch.FloatTensor of shape (metagraph.n)) – Scores for each uid as output by the gating model.

  • rewards (torch.FloatTensor of shape (metagraph.n)) – Rewards for each uid as output by the reward model.
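Conceptually, the loss described in the class overview amounts to the following; this is a sketch, and the exact normalization applied inside backward is an assumption (softmax is used here only for illustration):

```python
import torch.nn.functional as F


def gating_loss(scores, rewards):
    # assumed normalization; the class overview only says "normalized scores/rewards"
    normalized_scores = F.softmax(scores, dim=0)
    normalized_rewards = F.softmax(rewards, dim=0)
    return F.mse_loss(normalized_scores, normalized_rewards)
```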

forward(message)#

Runs a forward pass through the model.

Returns:

Scores for each uid as output by the gating model.

Return type:

torch.FloatTensor of shape (network_size)

Parameters:

message (str) – Text message to be encoded.
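Continuing the sketch above, the returned scores can be used to decide which uids receive a prompt; the value of k is an arbitrary placeholder:

```python
scores = gating.forward("What is a hash function?")
top_scores, top_uids = scores.topk(k=10)  # route the prompt to the 10 highest-scoring uids
```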

resync(previous_metagraph, metagraph)#

Resync the gating model with the latest state of the network.

Parameters:
  • previous_metagraph (bittensor.metagraph.Metagraph) – Previous state of the metagraph before the resync.

  • metagraph (bittensor.metagraph.Metagraph) – Latest state of the metagraph with updated uids and hotkeys.
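A validator would typically call resync whenever it refreshes its metagraph; this continues the earlier sketch and uses the same assumed bittensor call:

```python
previous_metagraph = metagraph
metagraph = bt.metagraph(netuid=1)            # refreshed network state (placeholder netuid)
gating.resync(previous_metagraph, metagraph)  # re-align gating outputs with the new uid set
```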

class prompting.validators.gating.SentenceEmbedGatingModel(metagraph, config=None, model_name=None, num_uids=None)#

Bases: BaseGatingModel

This class is a PyTorch module that encapsulates a custom version of a gating model based on sentence transformers.

  • The backward method runs a backward pass through the model using the mean squared error between the normalized scores and the normalized rewards as the loss function.

  • The forward method runs a forward pass through the model, encoding the input message and generating scores for each uid in the network. The scores are returned as a tensor.

Parameters:
  • metagraph (bittensor.metagraph.Metagraph) –

  • config (bittensor.config) –

  • model_name (str) –

  • num_uids (int) –
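A usage sketch mirroring the GatingModel example above; the model name is a placeholder for a sentence-transformers checkpoint, not a documented default:

```python
from prompting.validators.gating import SentenceEmbedGatingModel

sentence_gating = SentenceEmbedGatingModel(
    metagraph=metagraph,                                  # from the GatingModel sketch above
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder checkpoint
)
scores = sentence_gating.forward("Summarize the Bitcoin whitepaper.")
```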

mean_pooling(model_output, attention_mask)#

Applies mean pooling to the token embeddings generated by the model.

Parameters:
  • model_output (torch.Tensor) – Embedding model output, where the first element contains token embeddings.

  • attention_mask (torch.Tensor) – Attention mask to indicate valid tokens.

Returns:

Mean-pooled representation of the token embeddings.

Return type:

torch.Tensor

Notes

  • The function calculates the mean-pooled representation using the attention mask for valid tokens.

  • input_mask_expanded is created by expanding the attention mask to match the size of the token embeddings.

  • The result is obtained by summing the element-wise multiplication of embeddings and input_mask_expanded, and dividing it by the sum of input_mask_expanded after clamping its values to a minimum of 1e-9.
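The notes above correspond to the standard sentence-transformers mean-pooling pattern; the sketch below is a transcription of that description rather than the module's exact code:

```python
import torch


def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element holds the token embeddings
    # expand the attention mask to the embedding size so padded tokens are zeroed out
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, dim=1)
    counts = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    return summed / counts  # mean over valid (non-padding) tokens only
```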

forward(message)#

Runs a forward pass through the model.

Returns:

Scores for each uid as output by the gating model.

Return type:

torch.FloatTensor of shape (network_size)

Parameters:

message (str) – Text message to be encoded.

backward(scores, rewards)#

Runs a backward pass through the model.

Parameters:
  • scores (torch.FloatTensor of shape (metagraph.n)) – Scores for each uid as output by the gating model.

  • rewards (torch.FloatTensor of shape (metagraph.n)) – Rewards for each uid as output by the reward model.

resync(previous_metagraph, metagraph)#

Resync the gating model with the latest state of the network.

Parameters:
  • previous_metagraph (bittensor.metagraph.Metagraph) – Previous state of the metagraph before the resync.

  • metagraph (bittensor.metagraph.Metagraph) – Latest state of the metagraph with updated uids and hotkeys.