Skip to content

Generate perturbation effects

Generate perturbation effects for genes.

If max_mean_adjustment is greater than 0, then the mean of the effects is adjusted based on the binding_data and the function passed in adjustment_function. See default_perturbation_effect_adjustment_function() for the default option. If max_mean_adjustment is 0, then the mean is not adjusted. Any additional keyword arguments are passed along to the adjustment function.

Parameters:

Name Type Description Default
binding_data Tensor

A tensor of binding data with dimensions [n_genes, n_tfs, 3] where the entries in the third dimension are a matrix with columns [label, enrichment, pvalue].

required
tf_index int | None

The index of the TF in the binding_data tensor. Not used if we are adjusting the means (i.e., only used if max_mean_adjustment == 0). Defaults to None.

None
unbound_mean float

The mean for unbound genes. Defaults to 0.0

0.0
unbound_std float

The standard deviation for unbound genes. Defaults to 1.0

1.0
bound_mean float

The mean for bound genes. Defaults to 3.0

3.0
bound_std float

The standard deviation for bound genes. Defaults to 1.0

1.0
max_mean_adjustment float

The maximum adjustment to the base mean based on enrichment. Defaults to 0.0

0.0

Returns:

Type Description
torch.Tensor

A tensor of perturbation effects for each gene.

Raises:

Type Description
ValueError

If binding_data is not a 3D tensor with the third dimension having a length of 3

ValueError

If unbound_mean, unbound_std, bound_mean, bound_std, or max_mean_adjustment are not floats

Source code in yeastdnnexplorer/probability_models/generate_data.py
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
def generate_perturbation_effects(
    binding_data: torch.Tensor,
    tf_index: int | None = None,
    unbound_mean: float = 0.0,
    unbound_std: float = 1.0,
    bound_mean: float = 3.0,
    bound_std: float = 1.0,
    max_mean_adjustment: float = 0.0,
    adjustment_function: Callable[
        [torch.Tensor, float, float, float], torch.Tensor
    ] = default_perturbation_effect_adjustment_function,
    **kwargs,
) -> torch.Tensor:
    """
    Generate perturbation effects for genes.

    If `max_mean_adjustment` is greater than 0, then the mean of the
    effects is adjusted based on the binding_data and the function passed
    in `adjustment_function`. See `default_perturbation_effect_adjustment_function()`
    for the default option. If `max_mean_adjustment` is not greater than 0, then
    the mean is not adjusted and `tf_index` selects which TF's bound/unbound labels
    are used. Additional keyword arguments are passed along to the adjustment
    function.

    Every effect is drawn as ``sign * |N(mean, std)|`` where the sign is chosen
    uniformly at random from {-1, +1} once per gene.

    :param binding_data: A tensor of binding data with dimensions [n_genes, n_tfs, 3]
        where the entries in the third dimension are a matrix with columns
        [label, enrichment, pvalue].
    :type binding_data: torch.Tensor
    :param tf_index: The index of the TF in the binding_data tensor. Not used if we
        are adjusting the means (i.e. only used, and then required, if
        max_mean_adjustment is not greater than 0). Defaults to None
    :type tf_index: int, optional
    :param unbound_mean: The mean for unbound genes. Defaults to 0.0
    :type unbound_mean: float, optional
    :param unbound_std: The standard deviation for unbound genes. Defaults to 1.0
    :type unbound_std: float, optional
    :param bound_mean: The mean for bound genes. Defaults to 3.0
    :type bound_mean: float, optional
    :param bound_std: The standard deviation for bound genes. Defaults to 1.0
    :type bound_std: float, optional
    :param max_mean_adjustment: The maximum adjustment to the base mean based
        on enrichment. Defaults to 0.0
    :type max_mean_adjustment: float, optional
    :param adjustment_function: Callable used to compute the adjusted means when
        `max_mean_adjustment` is greater than 0. Its signature must contain
        parameters named `binding_enrichment_data`, `bound_mean`, `unbound_mean`
        and `max_adjustment`; it is called positionally as
        ``adjustment_function(binding_data, bound_mean, unbound_mean,
        max_mean_adjustment, **kwargs)`` and must return a 1D or 2D tensor of
        means whose first dimension indexes genes.
    :type adjustment_function: Callable, optional
    :param kwargs: Extra keyword arguments forwarded to `adjustment_function`.

    :return: A tensor of perturbation effects for each gene. When the adjustment
        function returns a 2D tensor, the result has that same shape and the
        same per-gene sign is applied to every column.
    :rtype: torch.Tensor

    :raises ValueError: If the means are not being adjusted and tf_index is None
    :raises ValueError: If binding_data is not a 3D tensor with the third
        dimension having a length of 3
    :raises ValueError: If unbound_mean, unbound_std, bound_mean, bound_std,
        or max_mean_adjustment are not floats (ints are rejected as well)
    :raises ValueError: If adjustment_function lacks one of the required
        parameter names

    """
    # check that a valid combination of inputs has been passed in
    if max_mean_adjustment == 0.0 and tf_index is None:
        raise ValueError("If max_mean_adjustment is 0, then tf_index must be specified")

    if binding_data.ndim != 3 or binding_data.shape[2] != 3:
        raise ValueError(
            "enrichment_tensor must have dimensions [num_genes, num_TFs, "
            "[label, enrichment, pvalue]]"
        )
    # check the rest of the inputs
    if not all(
        isinstance(i, float)
        for i in (unbound_mean, unbound_std, bound_mean, bound_std, max_mean_adjustment)
    ):
        raise ValueError(
            "unbound_mean, unbound_std, bound_mean, bound_std, "
            "and max_mean_adjustment must be floats"
        )
    # check the Callable signature
    if not all(
        i in inspect.signature(adjustment_function).parameters
        for i in (
            "binding_enrichment_data",
            "bound_mean",
            "unbound_mean",
            "max_adjustment",
        )
    ):
        raise ValueError(
            "adjustment_function must have the signature "
            "(binding_enrichment_data, bound_mean, unbound_mean, max_adjustment)"
        )

    adjust_means = max_mean_adjustment > 0 and adjustment_function is not None

    # A negative max_mean_adjustment also selects the non-adjusted path, which
    # indexes binding_data with tf_index. Indexing with None would silently insert
    # a new axis and fail later with an unrelated shape error, so reject it here.
    if not adjust_means and tf_index is None:
        raise ValueError(
            "tf_index must be specified when the means are not adjusted "
            "(max_mean_adjustment <= 0)"
        )

    device = binding_data.device
    n_genes = binding_data.size(0)

    # Randomly assign signs (-1 or +1) for each gene
    signs = (
        torch.randint(0, 2, (n_genes,), dtype=torch.float32, device=device) * 2 - 1
    )

    # Apply adjustments to the base mean for the bound genes, if necessary
    if adjust_means:
        # Assuming adjustment_function returns a vector of means for each gene.
        # bound genes that meet the criteria for adjustment will be affected by
        # the status of the TFs. What TFs affect a given gene must be specified by
        # the adjustment_function()
        adjusted_means = adjustment_function(
            binding_data,
            bound_mean,
            unbound_mean,
            max_mean_adjustment,
            **kwargs,
        )

        # add adjustments, ensuring they respect the original sign
        if adjusted_means.ndim == 1:
            return signs * torch.abs(torch.normal(mean=adjusted_means, std=bound_std))

        # One draw per column, reusing the same per-gene signs for every column.
        effects = torch.zeros_like(adjusted_means)
        for col_idx in range(effects.size(1)):
            effects[:, col_idx] = signs * torch.abs(
                torch.normal(mean=adjusted_means[:, col_idx], std=bound_std)
            )
        return effects

    bound_mask = binding_data[:, tf_index, 0] == 1
    unbound_mask = ~bound_mask

    effects = torch.empty(n_genes, dtype=torch.float32, device=device)

    # Generate effects based on the unbound and bound means, applying the sign.
    # With scalar mean/std, torch.normal allocates on the default device unless
    # told otherwise, so pass the device explicitly to match signs/effects.
    effects[unbound_mask] = signs[unbound_mask] * torch.abs(
        torch.normal(
            mean=unbound_mean,
            std=unbound_std,
            size=(int(unbound_mask.sum()),),
            device=device,
        )
    )
    effects[bound_mask] = signs[bound_mask] * torch.abs(
        torch.normal(
            mean=bound_mean,
            std=bound_std,
            size=(int(bound_mask.sum()),),
            device=device,
        )
    )

    return effects