
Perturbation effect adjustment function with TF relationships

Adjust the mean of the perturbation effect based on the enrichment score and the provided relationships between TFs. For each gene, a TF-gene pair receives bound_mean plus an enrichment-scaled adjustment if the TF is bound to the gene and all related TFs are also bound to the gene. Every other pair receives unbound_mean. The adjustment is max_adjustment multiplied by the pair's enrichment score divided by the largest enrichment score in binding_enrichment_data, so it grows with enrichment rather than being drawn at random.
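To make the rule concrete, the sketch below restates it with tensor operations instead of the per-gene Python loop used in the source. `adjusted_means_vectorized` is a hypothetical name, not part of yeastdnnexplorer; it assumes a label of 1 means "bound" (as the source checks) and it skips the input validation the real function performs.

```python
import torch


def adjusted_means_vectorized(
    binding_enrichment_data: torch.Tensor,
    bound_mean: float,
    unbound_mean: float,
    max_adjustment: float,
    tf_relationships: dict[int, list[int]],
) -> torch.Tensor:
    """Hypothetical vectorized restatement of the per-pair rule (no input validation)."""
    bound = binding_enrichment_data[:, :, 0] == 1  # (n_genes, n_tfs), bool
    enrichment = binding_enrichment_data[:, :, 1]  # (n_genes, n_tfs)

    # qualifies[g, t] is True when TF t and every TF related to t are bound to gene g
    qualifies = bound.clone()
    for tf_index, related_tfs in tf_relationships.items():
        if related_tfs:  # an empty list leaves the TF's own bound label unchanged
            partners_bound = bound[:, related_tfs].all(dim=1)  # (n_genes,)
            qualifies[:, tf_index] = qualifies[:, tf_index] & partners_bound

    # Scale max_adjustment by each pair's enrichment relative to the largest score
    multiplier = enrichment / enrichment.max().abs()
    adjusted = bound_mean + multiplier * max_adjustment
    return torch.where(qualifies, adjusted, torch.full_like(enrichment, unbound_mean))


# Illustrative 2-gene x 2-TF input: TF 0 requires TF 1, TF 1 has no partners
labels = torch.tensor([[1.0, 1.0], [1.0, 0.0]])
enrichment = torch.tensor([[2.0, 4.0], [1.0, 3.0]])
data = torch.stack([labels, enrichment, torch.zeros(2, 2)], dim=-1)
result = adjusted_means_vectorized(data, 3.0, 0.0, 1.0, {0: [1], 1: []})
# expected values: [[3.5, 4.0], [0.0, 0.0]]
```

For gene 0 both TFs qualify, giving 3.0 + 2/4 and 3.0 + 4/4. For gene 1, TF 0 fails because its partner TF 1 is unbound, and TF 1 fails because it is unbound itself, so both get unbound_mean. When the inputs pass the real function's validation, this should yield the same values as its loop while avoiding Python-level iteration over genes.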

Parameters:

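| Name | Type | Description | Default |
|---|---|---|---|
| binding_enrichment_data | Tensor | A tensor of shape [n_genes, n_tfs, 3]. Along the third dimension, each TF-gene pair stores [label, enrichment, pvalue], where a label of 1 marks the pair as bound. | required |
| bound_mean | float | The base mean for qualifying bound TF-gene pairs, before the enrichment-based adjustment is added. | required |
| unbound_mean | float | The mean assigned to every TF-gene pair that does not qualify (the TF is unbound, or not all of its related TFs are bound). | required |
| max_adjustment | float | The maximum adjustment added to bound_mean, scaled by the pair's enrichment score relative to the largest enrichment score. | required |
| tf_relationships | dict[int, list[int]] | A dictionary whose keys are TF indices and whose values are lists of indices of the other TFs related to the key TF. Every TF index must appear as a key. | required |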
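As an illustration of the expected input layout, the snippet below builds a toy binding_enrichment_data tensor by stacking three [n_genes, n_tfs] channels along a new last axis, plus a matching tf_relationships dict. The numbers are made up for illustration only. Note that, going by the source, the function reads only the label and enrichment channels; the p-value channel is carried along but not used.

```python
import torch

# Hypothetical toy inputs: 3 genes x 2 TFs (values are illustrative only)
labels = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # 1 = bound, 0 = unbound
enrichment = torch.tensor([[5.0, 0.2], [3.0, 4.5], [0.1, 2.0]])
pvalues = torch.tensor([[0.01, 0.80], [0.02, 0.01], [0.90, 0.03]])

# Stack the three per-pair channels along a new last axis -> [n_genes, n_tfs, 3]
binding_enrichment_data = torch.stack([labels, enrichment, pvalues], dim=-1)

# TF 0 only qualifies for a gene when TF 1 is also bound; TF 1 has no related TFs
tf_relationships = {0: [1], 1: []}

# The same slices the function uses to recover labels and enrichment scores
bound_labels = binding_enrichment_data[:, :, 0]
enrichment_scores = binding_enrichment_data[:, :, 1]
print(binding_enrichment_data.shape)  # torch.Size([3, 2, 3])
```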

Returns:

| Type | Description |
|---|---|
| torch.Tensor | Adjusted mean as a tensor of shape [n_genes, n_tfs]. |

Raises:

| Type | Description |
|---|---|
| ValueError | If tf_relationships is not a dictionary mapping ints to lists of ints. |
| ValueError | If the tf_relationships dict does not have the same number of TFs as the binding_enrichment_data tensor passed into the function. |
| ValueError | If the tf_relationships dict has any TFs in the values that are not also in the keys, or any key or value TFs that are out of bounds for the binding_enrichment_data tensor. |
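The three checks can be sketched as a standalone validator. `check_tf_relationships` is a hypothetical helper that mirrors the conditions in the source (type check, then bounds check, then length check) with shortened messages. There is no separate "values must also be keys" test: once every key is in range and there is exactly one key per TF, the keys are exactly 0..n_tfs-1, so any in-bounds value index is necessarily a key as well.

```python
def check_tf_relationships(tf_relationships: object, n_tfs: int) -> None:
    """Hypothetical standalone version of the three ValueError checks."""
    # 1. Must be a dict mapping ints to lists of ints
    if (
        not isinstance(tf_relationships, dict)
        or not all(isinstance(k, int) for k in tf_relationships)
        or not all(isinstance(v, list) for v in tf_relationships.values())
        or not all(isinstance(i, int) for v in tf_relationships.values() for i in v)
    ):
        raise ValueError("tf_relationships must map ints to lists of ints")

    # 2. Every key and every related-TF index must be a valid TF index
    indices = set(tf_relationships) | {i for v in tf_relationships.values() for i in v}
    if not all(0 <= i < n_tfs for i in indices):
        raise ValueError("TF index out of bounds for the binding data")

    # 3. Exactly one key per TF (with check 2, this makes the keys 0..n_tfs-1)
    if len(tf_relationships) != n_tfs:
        raise ValueError("tf_relationships must have one key per TF")
```

For example, `{0: [1], 1: []}` passes for two TFs, while `{0: [1]}` fails the length check and `{0: [5], 1: []}` fails the bounds check.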

Source code in yeastdnnexplorer/probability_models/generate_data.py
import torch


def perturbation_effect_adjustment_function_with_tf_relationships(
    binding_enrichment_data: torch.Tensor,
    bound_mean: float,
    unbound_mean: float,
    max_adjustment: float,
    tf_relationships: dict[int, list[int]],
) -> torch.Tensor:
    """
    Adjust the mean of the perturbation effect based on the enrichment score and the
    provided relationships between TFs. For each gene, a TF-gene pair receives
    bound_mean plus an enrichment-scaled adjustment if the TF is bound to the gene and
    all related TFs are also bound to the gene; every other pair receives
    unbound_mean. The adjustment is max_adjustment multiplied by the pair's enrichment
    score divided by the largest enrichment score in binding_enrichment_data.

    :param binding_enrichment_data: A tensor of shape [n_genes, n_tfs, 3] where the
        third dimension holds [label, enrichment, pvalue] for each TF-gene pair and a
        label of 1 marks the pair as bound.
    :type binding_enrichment_data: torch.Tensor
    :param bound_mean: The base mean for qualifying bound TF-gene pairs.
    :type bound_mean: float
    :param unbound_mean: The mean for TF-gene pairs that do not qualify.
    :type unbound_mean: float
    :param max_adjustment: The maximum adjustment to the base mean based on enrichment.
    :type max_adjustment: float
    :param tf_relationships: A dictionary where the keys are the indices of the TFs and
        the values are lists of indices of other TFs that are related to the key TF.
    :type tf_relationships: dict[int, list[int]]
    :return: Adjusted mean as a tensor of shape [n_genes, n_tfs].
    :rtype: torch.Tensor
    :raises ValueError: If tf_relationships is not a dictionary mapping ints to lists
        of ints
    :raises ValueError: If the tf_relationships dict does not have the same number of
        TFs as the binding_enrichment_data tensor passed into the function
    :raises ValueError: If the tf_relationships dict has any TFs in the values that are
        not also in the keys or any key or value TFs that are out of bounds for the
        binding_enrichment_data tensor

    """
    if (
        not isinstance(tf_relationships, dict)
        or not all(isinstance(v, list) for v in tf_relationships.values())
        or not all(isinstance(k, int) for k in tf_relationships.keys())
        or not all(isinstance(i, int) for v in tf_relationships.values() for i in v)
    ):
        raise ValueError(
            "tf_relationships must be a dictionary between ints and lists of ints"
        )
    num_tfs = binding_enrichment_data.shape[1]
    if not all(k in range(num_tfs) for k in tf_relationships.keys()) or not all(
        i in range(num_tfs) for v in tf_relationships.values() for i in v
    ):
        # Adjacent string literals are joined without the stray indentation that a
        # backslash continuation inside the literal would embed in the message
        raise ValueError(
            "all keys and values in tf_relationships must be within the "
            "bounds of the binding_data tensor's number of TFs"
        )
    if len(tf_relationships) != num_tfs:
        raise ValueError(
            "tf_relationships must have the same number of TFs as the "
            "binding_data tensor passed into the function"
        )

    # Extract bound/unbound labels and enrichment scores
    bound_labels = binding_enrichment_data[:, :, 0]  # shape: (num_genes, num_tfs)
    enrichment_scores = binding_enrichment_data[:, :, 1]  # shape: (num_genes, num_tfs)

    # Normalizing constant for the adjustment multiplier, computed once rather than
    # once per pair. If the largest enrichment score is 0, the division below
    # produces inf/nan.
    max_enrichment = abs(enrichment_scores.max())

    # Start from the enrichment score for bound pairs and 0 for unbound pairs. Because
    # tf_relationships has one key per TF, the loop below overwrites every entry.
    adjusted_mean_matrix = torch.where(
        bound_labels == 1, enrichment_scores, torch.zeros_like(enrichment_scores)
    )  # shape: (num_genes, num_tfs)

    for gene_idx in range(bound_labels.shape[0]):
        for tf_index, related_tfs in tf_relationships.items():
            if bound_labels[gene_idx, tf_index] == 1 and torch.all(
                bound_labels[gene_idx, related_tfs] == 1
            ):
                # Divide the pair's enrichment score by the largest score to get an
                # adjustment multiplier that grows with enrichment
                adjustment_multiplier = (
                    enrichment_scores[gene_idx, tf_index] / max_enrichment
                )

                # Shift the bound mean by that fraction of the max adjustment
                adjusted_mean_matrix[gene_idx, tf_index] = bound_mean + (
                    adjustment_multiplier * max_adjustment
                )
            else:
                # The TF is unbound or not all related TFs are bound:
                # use the unbound mean
                adjusted_mean_matrix[gene_idx, tf_index] = unbound_mean

    return adjusted_mean_matrix  # shape (num_genes, num_tfs)
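One consequence of the condition `torch.all(bound_labels[gene_idx, related_tfs] == 1)` is worth spelling out: `torch.all` over an empty tensor returns True, so a TF whose related-TF list is empty qualifies whenever it is bound itself. The short check below illustrates this with a made-up row of labels for a single gene; the explicit empty index tensor stands in for `related_tfs == []`.

```python
import torch

# Labels for one hypothetical gene across 3 TFs (1 = bound)
bound_row = torch.tensor([1.0, 0.0, 1.0])

# No related TFs: the selection is empty and torch.all of an empty tensor is True,
# so the condition reduces to "the TF itself is bound"
no_partners = torch.all(bound_row[torch.tensor([], dtype=torch.long)] == 1)

# Related TF 1 is unbound, so the condition fails
unbound_partner = torch.all(bound_row[torch.tensor([1])] == 1)

# Related TFs 0 and 2 are both bound, so the condition holds
bound_partners = torch.all(bound_row[torch.tensor([0, 2])] == 1)
```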