
Metrics smse

Bases: Metric

A class for computing the standardized mean squared error (SMSE) metric.

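This metric is defined as the mean squared error divided by the variance of the true values (the target data). Because the error is divided by the variance of the true values, the metric is scale-independent and does not depend on the mean of the true values: multiplying both the predictions and the targets by the same constant, or adding the same constant to both, leaves it unchanged. Under this definition, a model that always predicts the mean of the targets scores 1, so values below 1 indicate an improvement over that baseline. This makes it possible to compare models across datasets with differing scales or means (as long as their variances are at least roughly similar).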
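
A minimal pure-Python sketch of this definition follows. The `smse` helper is hypothetical and not part of yeastdnnexplorer; it uses the population variance to mirror `unbiased=False` in the class, and it checks the scale and shift invariance and the mean-predictor baseline described above.

```python
from statistics import fmean, pvariance


def smse(y_pred: list[float], y_true: list[float]) -> float:
    """Hypothetical helper: MSE divided by the population variance of y_true."""
    mse = fmean((p - t) ** 2 for p, t in zip(y_pred, y_true))
    return mse / pvariance(y_true)


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 1.5, 3.5, 3.5]

base = smse(y_pred, y_true)
# Rescaling predictions and targets by the same factor leaves SMSE unchanged...
scaled = smse([10 * p for p in y_pred], [10 * t for t in y_true])
# ...and so does shifting both by the same constant.
shifted = smse([p + 100 for p in y_pred], [t + 100 for t in y_true])
# Always predicting the mean of the targets gives an SMSE of 1.
baseline = smse([fmean(y_true)] * len(y_true), y_true)

print(base, scaled, shifted, baseline)
```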

Source code in yeastdnnexplorer/ml_models/metrics.py
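# Imports assumed from the top of metrics.py (not shown in this excerpt).
import torch
import torch.nn.functional as F
from torchmetrics import Metric


class SMSE(Metric):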
    """
    A class for computing the standardized mean squared error (SMSE) metric.

    This metric is defined as the mean squared error divided by the variance of the true
    values (the target data). Because we are dividing by the variance of the true
    values, this metric is scale-independent and does not depend on the mean of the true
    values. It allows us to effectively compare models drawn from different datasets
    with differing scales or means (as long as their variances are at least relatively
    similar).

    """

    def __init__(self):
        """Initialize the SMSE metric."""
        super().__init__()
        self.add_state("mse", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("variance", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor):
        """
        Update the metric with new predictions and true values.

        :param y_pred: The predicted y values
        :type y_pred: torch.Tensor
        :param y_true: The true y values
        :type y_true: torch.Tensor

        """
        self.mse += F.mse_loss(y_pred, y_true, reduction="sum")
        # Weight the batch's population variance (unbiased=False, matching the 1/N
        # averaging of the MSE) by its element count, consistent with num_samples.
        self.variance += torch.var(y_true, unbiased=False) * y_true.numel()
        self.num_samples += y_true.numel()

    def compute(self):
        """
        Compute the SMSE metric.

        :return: The SMSE metric
        :rtype: torch.Tensor

        """
        mean_mse = self.mse / self.num_samples
        mean_variance = self.variance / self.num_samples
        return mean_mse / mean_variance
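
As written, `update` adds each batch's population variance multiplied by the batch size, so `compute` divides by a size-weighted average of within-batch variances rather than by the variance of all targets taken together. The two agree for a single batch but can diverge when batch means differ. The sketch below is a hypothetical pure-Python mirror of that bookkeeping (`SMSESketch` is an illustrative name, not library code, and it assumes 1-D targets); it compares the accumulated value with one global computation over the same data.

```python
from statistics import fmean, pvariance


class SMSESketch:
    """Hypothetical pure-Python mirror of SMSE's running state (no torch)."""

    def __init__(self) -> None:
        self.mse = 0.0  # running sum of squared errors
        self.variance = 0.0  # running sum of batch population variance * batch size
        self.num_samples = 0  # running element count

    def update(self, y_pred: list[float], y_true: list[float]) -> None:
        self.mse += sum((p - t) ** 2 for p, t in zip(y_pred, y_true))
        self.variance += pvariance(y_true) * len(y_true)
        self.num_samples += len(y_true)

    def compute(self) -> float:
        return (self.mse / self.num_samples) / (self.variance / self.num_samples)


# Two batches whose targets have very different means.
batches = [([0.5, 1.5], [0.0, 2.0]), ([10.5, 11.5], [10.0, 12.0])]

metric = SMSESketch()
for y_pred, y_true in batches:
    metric.update(y_pred, y_true)

# The same data scored in one pass, using the variance of all targets.
all_pred = [p for y_pred, _ in batches for p in y_pred]
all_true = [t for _, y_true in batches for t in y_true]
global_smse = fmean((p - t) ** 2 for p, t in zip(all_pred, all_true)) / pvariance(all_true)

print(metric.compute(), global_smse)
```

Here the batched value is 0.25 while the global value is much smaller, because the spread between the two batch means never enters the accumulated denominator. If a dataset-wide SMSE is what is intended, one option (an assumption, not something the source does) would be to accumulate the sum and the sum of squares of `y_true` and form the variance inside `compute`.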

__init__()

Initialize the SMSE metric.
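
Each of the three states is registered with `dist_reduce_fx="sum"`, which in torchmetrics means the per-process values are added together when metric states are synchronized across processes. That is valid here because every state is a running sum. The following is a hypothetical pure-Python illustration (`SMSEState` and `merge` are illustrative names, not part of the library).

```python
from dataclasses import dataclass


@dataclass
class SMSEState:
    """Hypothetical stand-in for the three states registered in __init__."""

    mse: float = 0.0
    variance: float = 0.0
    num_samples: int = 0

    def merge(self, other: "SMSEState") -> "SMSEState":
        # Mimics dist_reduce_fx="sum": each state is reduced by addition.
        return SMSEState(
            self.mse + other.mse,
            self.variance + other.variance,
            self.num_samples + other.num_samples,
        )

    def compute(self) -> float:
        return (self.mse / self.num_samples) / (self.variance / self.num_samples)


# Illustrative partial sums, as two workers might hold after their own update() calls.
worker_a = SMSEState(mse=0.5, variance=2.0, num_samples=2)
worker_b = SMSEState(mse=1.5, variance=6.0, num_samples=3)

merged = worker_a.merge(worker_b)
print(merged, merged.compute())
```

Storing an already-averaged quantity instead would not combine correctly by simple addition, which is presumably why the states hold sums and the division is deferred to `compute`.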

Source code in yeastdnnexplorer/ml_models/metrics.py
def __init__(self):
    """Initialize the SMSE metric."""
    super().__init__()
    self.add_state("mse", default=torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state("variance", default=torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum")

compute()

Compute the SMSE metric.

Returns:

Type          Description
torch.Tensor  The SMSE metric

Source code in yeastdnnexplorer/ml_models/metrics.py
def compute(self):
    """
    Compute the SMSE metric.

    :return: The SMSE metric
    :rtype: torch.Tensor

    """
    mean_mse = self.mse / self.num_samples
    mean_variance = self.variance / self.num_samples
    return mean_mse / mean_variance
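
Writing out what `compute` returns, for 1-D targets (where `y_true.size(0)` and `y_true.numel()` coincide), with batches $b$ of size $n_b$, predictions $\hat{y}_i$, targets $y_i$, batch population variances $\sigma_b^2$, and $N = \sum_b n_b$ total elements, the factors of $1/N$ cancel:

$$
\mathrm{SMSE}
= \frac{\frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2}{\frac{1}{N}\sum_{b} n_b\,\sigma_b^2}
= \frac{\sum_{i=1}^{N} (\hat{y}_i - y_i)^2}{\sum_{b} n_b\,\sigma_b^2}
$$

With a single batch this reduces to the definition given at the top of the page. If `update` has never been called, or every batch has constant targets, the denominator is zero and the result is not finite.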

update(y_pred, y_true)

Update the metric with new predictions and true values.

Parameters:

Name    Type    Description             Default
y_pred  Tensor  The predicted y values  required
y_true  Tensor  The true y values       required

Source code in yeastdnnexplorer/ml_models/metrics.py
def update(self, y_pred: torch.Tensor, y_true: torch.Tensor):
    """
    Update the metric with new predictions and true values.

    :param y_pred: The predicted y values
    :type y_pred: torch.Tensor
    :param y_true: The true y values
    :type y_true: torch.Tensor

    """
    self.mse += F.mse_loss(y_pred, y_true, reduction="sum")
    # Weight the batch's population variance (unbiased=False, matching the 1/N
    # averaging of the MSE) by its element count, consistent with num_samples.
    self.variance += torch.var(y_true, unbiased=False) * y_true.numel()
    self.num_samples += y_true.numel()
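
For one batch, `update` adds three quantities to the running state. The following hypothetical pure-Python sketch computes those increments by hand for a small 1-D batch (the values are illustrative, and the sketch does not call torch).

```python
from statistics import pvariance

# A single illustrative batch of 1-D targets and predictions.
y_true = [1.0, 3.0, 5.0, 7.0]
y_pred = [2.0, 3.0, 4.0, 7.0]

# Increment to `mse`: sum of squared errors (F.mse_loss with reduction="sum").
mse_increment = sum((p - t) ** 2 for p, t in zip(y_pred, y_true))
# Increment to `variance`: population variance (unbiased=False) times batch size.
variance_increment = pvariance(y_true) * len(y_true)
# Increment to `num_samples`: number of elements in the batch.
num_samples_increment = len(y_true)

# After only this batch, compute() would return (2.0 / 4) / (20.0 / 4) = 0.1.
smse_single_batch = (mse_increment / num_samples_increment) / (
    variance_increment / num_samples_increment
)
print(mse_increment, variance_increment, num_samples_increment, smse_single_batch)
```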