Skip to content

InteractorDiagnosticPlot

InteractorDiagnosticPlot

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class InteractorDiagnosticPlot:
    def __init__(
        self,
        df: pd.DataFrame,
        quantile: float,
        B: NDArray[np.float_],  # Accept a tuple of three coefficients
        model_type: Literal["linear", "sigmoid"] = "linear",
        font_size: int = 18,
        line_thickness: float = 3,
        legend_font_size: int = 12,
        axis_tick_font_size: int = 14,
        left_asymptote: float | None = None,
        right_asymptote: float | None = None,
    ):
        """Initialize the class with data and model parameters."""
        self.df = df
        self.quantile = quantile
        self.B = B  # Expect B as a tuple
        self.model_type = model_type
        self.font_size = font_size
        self.line_thickness = line_thickness
        self.legend_font_size = legend_font_size
        self.axis_tick_font_size = axis_tick_font_size
        self.left_asymptote = left_asymptote
        self.right_asymptote = right_asymptote

        # Validate inputs when initializing
        self.validate_inputs()

    def validate_inputs(self):
        """Validates the inputs for the class."""
        if not isinstance(self.df, pd.DataFrame) or self.df.shape[1] != 3:
            raise ValueError("`df` must be a DataFrame with exactly three columns.")
        if not (0 < self.quantile < 1):
            raise ValueError("Quantile must be between 0 and 1.")
        if len(self.B) != 3 or not all(isinstance(b, (float)) for b in self.B):
            raise ValueError("B must be a tuple with exactly three float values.")
        if self.model_type not in ["linear", "sigmoid"]:
            raise ValueError("model_type must be either 'linear' or 'sigmoid'.")
        if self.model_type == "sigmoid":
            if self.left_asymptote is None or self.right_asymptote is None:
                raise ValueError(
                    "You must provide left_asymptote and "
                    "right_asymptote for sigmoid model."
                )

    def compute_ecdf(self):
        """Computes the ECDF for the third column and splits the data into bottom and
        top quantiles."""
        self.df["ecdf"] = self.df.iloc[:, 2].rank(method="max") / len(self.df)
        self.df_bottom = self.df[self.df["ecdf"] <= self.quantile].copy()
        self.df_top = self.df[self.df["ecdf"] >= (1 - self.quantile)].copy()

        # Remove rows where X1 > 2.0 for cleaner plotting
        self.df_bottom = self.df_bottom[self.df_bottom.iloc[:, 1] <= 2.0]
        self.df_top = self.df_top[self.df_top.iloc[:, 1] <= 2.0]

    def calculate_means(self):
        """Calculates the mean of the interactor (X2) for bottom and top quantiles."""
        self.bottom_x2_mean = self.df_bottom.iloc[:, 2].mean()
        self.top_x2_mean = self.df_top.iloc[:, 2].mean()

    def add_jitter(self, data: np.ndarray, jitter_strength: float = 0.05) -> np.ndarray:
        """Adds jitter to a dataset to display point density."""
        return data + np.random.uniform(-jitter_strength, jitter_strength, len(data))

    def create_model_lines(self, **kwargs):
        """Creates model prediction lines for bottom and top quantiles."""
        self.x_vals = np.linspace(0, 2.0, 100)

        if self.model_type == "linear":
            # Linear model: B_0 + B_1 * x + B_2 * x * X2_mean
            self.bottom_line = (
                self.B[0]
                + self.B[1] * self.x_vals
                + self.B[2] * self.bottom_x2_mean * self.x_vals
            )
            self.top_line = (
                self.B[0]
                + self.B[1] * self.x_vals
                + self.B[2] * self.top_x2_mean * self.x_vals
            )

        elif self.model_type == "sigmoid":
            assert self.left_asymptote is not None and self.right_asymptote is not None

            x_lower = np.column_stack(
                (
                    np.ones_like(self.x_vals),
                    self.x_vals,
                    self.x_vals * self.bottom_x2_mean,
                )
            )
            x_upper = np.column_stack(
                (np.ones_like(self.x_vals), self.x_vals, self.x_vals * self.top_x2_mean)
            )

            # Compute the bottom and top model lines using sigmoid

            self.bottom_line = sigmoid(
                X=x_lower,
                left_asymptote=self.left_asymptote,
                right_asymptote=self.right_asymptote,
                B=self.B,
            )
            self.top_line = sigmoid(
                X=x_upper,
                left_asymptote=self.left_asymptote,
                right_asymptote=self.right_asymptote,
                B=self.B,
            )

        else:
            raise ValueError("Unsupported model_type. Use 'linear' or 'sigmoid'.")

    def create_plot(self, ax, x_data, y_data, line_data, color, label, mean_value):
        """Creates a scatter plot with a fitted line for a given quantile."""
        ax.scatter(x_data, y_data, color=color, alpha=0.5, label=label)
        ax.plot(
            self.x_vals,
            line_data,
            color=color,
            label=f"Model (Mean = {mean_value:.2f})",
            linewidth=self.line_thickness,
        )
        ax.axhline(0, color="black", linestyle="--", linewidth=self.line_thickness)
        ax.tick_params(axis="both", which="major", labelsize=self.axis_tick_font_size)
        ax.legend(fontsize=self.legend_font_size, loc="upper left", frameon=False)
        ax.set_xlabel("Perturbed TF binding strength", fontsize=self.font_size)
        ax.set_ylabel("Perturbed TF perturbation response", fontsize=self.font_size)

    def plot(self, **kwargs):
        """Main plotting method that returns the `plt` object for further
        customization."""
        if self.model_type == "sigmoid":
            for k in ["left_asymptote", "right_asymptote", "B"]:
                if k in kwargs:
                    raise ValueError(
                        f"You must provide {k} as a keyword argument for sigmoid model."
                    )
        # Compute ECDF and split data
        self.compute_ecdf()

        # Calculate means of interactor (X2) for both quantiles
        self.calculate_means()

        # Create model lines for the plots
        self.create_model_lines(**kwargs)

        # Add jitter to Y values
        self.df_bottom.iloc[:, 0] = self.add_jitter(self.df_bottom.iloc[:, 0])
        self.df_top.iloc[:, 0] = self.add_jitter(self.df_top.iloc[:, 0])

        # Create subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # Plot bottom quantile
        self.create_plot(
            ax1,
            self.df_bottom.iloc[:, 1],
            self.df_bottom.iloc[:, 0],
            self.bottom_line,
            "blue",
            "Bottom 10% by Met31 binding",
            self.bottom_x2_mean,
        )

        # Plot top quantile
        self.create_plot(
            ax2,
            self.df_top.iloc[:, 1],
            self.df_top.iloc[:, 0],
            self.top_line,
            "red",
            "Top 10% by Met31 binding",
            self.top_x2_mean,
        )

        # Return plt object for further customization
        return plt

    def __call__(self, **kwargs):
        """Allows the object to be called like a function, invoking the `plot`
        method."""
        return self.plot(**kwargs)

__call__(**kwargs)

Allows the object to be called like a function, invoking the `plot` method.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
192
193
194
195
def __call__(self, **kwargs):
    """Make instances callable by forwarding every keyword argument to ``plot``.

    Returns:
        Whatever ``self.plot`` returns.
    """
    plot_method = self.plot
    return plot_method(**kwargs)

__init__(df, quantile, B, model_type='linear', font_size=18, line_thickness=3, legend_font_size=12, axis_tick_font_size=14, left_asymptote=None, right_asymptote=None)

Initialize the class with data and model parameters.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def __init__(
    self,
    df: pd.DataFrame,
    quantile: float,
    B: NDArray[np.float_],  # Accept a tuple of three coefficients
    model_type: Literal["linear", "sigmoid"] = "linear",
    font_size: int = 18,
    line_thickness: float = 3,
    legend_font_size: int = 12,
    axis_tick_font_size: int = 14,
    left_asymptote: float | None = None,
    right_asymptote: float | None = None,
):
    """Initialize the class with data and model parameters."""
    self.df = df
    self.quantile = quantile
    self.B = B  # Expect B as a tuple
    self.model_type = model_type
    self.font_size = font_size
    self.line_thickness = line_thickness
    self.legend_font_size = legend_font_size
    self.axis_tick_font_size = axis_tick_font_size
    self.left_asymptote = left_asymptote
    self.right_asymptote = right_asymptote

    # Validate inputs when initializing
    self.validate_inputs()

add_jitter(data, jitter_strength=0.05)

Adds jitter to a dataset to display point density.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
73
74
75
def add_jitter(self, data: np.ndarray, jitter_strength: float = 0.05) -> np.ndarray:
    """Return ``data`` shifted by independent uniform noise so overlapping points separate.

    Each element gets one draw from U(-jitter_strength, jitter_strength), taken
    from NumPy's global random state.
    """
    n_points = len(data)
    noise = np.random.uniform(low=-jitter_strength, high=jitter_strength, size=n_points)
    return data + noise

calculate_means()

Calculates the mean of the interactor (X2) for bottom and top quantiles.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
68
69
70
71
def calculate_means(self):
    """Store the mean interactor value (column 2, X2) for each quantile group.

    Sets ``bottom_x2_mean`` from ``df_bottom`` and ``top_x2_mean`` from ``df_top``.
    """

    def _interactor_mean(frame):
        # The interactor is the third column, addressed by position.
        return frame.iloc[:, 2].mean()

    self.bottom_x2_mean = _interactor_mean(self.df_bottom)
    self.top_x2_mean = _interactor_mean(self.df_top)

compute_ecdf()

Computes the ECDF for the third column and splits the data into bottom and top quantiles.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
57
58
59
60
61
62
63
64
65
66
def compute_ecdf(self):
    """Split ``df`` into bottom and top quantile groups by the ECDF of column 2.

    Sets ``df_bottom`` (ECDF <= quantile) and ``df_top`` (ECDF >= 1 - quantile).
    Each is an independent copy that carries an extra ``ecdf`` column, and each
    is restricted to rows with X1 (column 1) <= 2.0. The caller's DataFrame is
    not modified.
    """
    # Work on a copy. Adding the "ecdf" column to self.df directly would mutate
    # the caller's DataFrame and give it a fourth column, after which the
    # three-column validation would reject that same DataFrame.
    ranked = self.df.copy()
    ranked["ecdf"] = ranked.iloc[:, 2].rank(method="max") / len(ranked)
    bottom = ranked[ranked["ecdf"] <= self.quantile]
    top = ranked[ranked["ecdf"] >= (1 - self.quantile)]

    # Remove rows where X1 > 2.0 for cleaner plotting. Copy so that later
    # in-place writes (e.g. jitter) target an independent frame, not a slice.
    self.df_bottom = bottom[bottom.iloc[:, 1] <= 2.0].copy()
    self.df_top = top[top.iloc[:, 1] <= 2.0].copy()

create_model_lines(**kwargs)

Creates model prediction lines for bottom and top quantiles.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def create_model_lines(self, **kwargs):
    """Evaluate the model on 100 X1 values in [0, 2] for both quantile groups.

    Sets ``x_vals``, ``bottom_line`` and ``top_line``. X2 is held at
    ``bottom_x2_mean`` / ``top_x2_mean``, which must already be set.
    ``kwargs`` are accepted but not used.

    Raises:
        ValueError: If ``model_type`` is unsupported, or is ``"sigmoid"`` while
            an asymptote is missing.
    """
    self.x_vals = np.linspace(0, 2.0, 100)

    if self.model_type == "linear":
        # Linear model: B_0 + B_1 * x + B_2 * x * X2_mean
        self.bottom_line = (
            self.B[0]
            + self.B[1] * self.x_vals
            + self.B[2] * self.bottom_x2_mean * self.x_vals
        )
        self.top_line = (
            self.B[0]
            + self.B[1] * self.x_vals
            + self.B[2] * self.top_x2_mean * self.x_vals
        )

    elif self.model_type == "sigmoid":
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if self.left_asymptote is None or self.right_asymptote is None:
            raise ValueError(
                "You must provide left_asymptote and "
                "right_asymptote for sigmoid model."
            )

        # Design matrices with columns [1, x, x * X2_mean] for each group
        x_lower = np.column_stack(
            (
                np.ones_like(self.x_vals),
                self.x_vals,
                self.x_vals * self.bottom_x2_mean,
            )
        )
        x_upper = np.column_stack(
            (np.ones_like(self.x_vals), self.x_vals, self.x_vals * self.top_x2_mean)
        )

        self.bottom_line = sigmoid(
            X=x_lower,
            left_asymptote=self.left_asymptote,
            right_asymptote=self.right_asymptote,
            B=self.B,
        )
        self.top_line = sigmoid(
            X=x_upper,
            left_asymptote=self.left_asymptote,
            right_asymptote=self.right_asymptote,
            B=self.B,
        )

    else:
        raise ValueError("Unsupported model_type. Use 'linear' or 'sigmoid'.")

create_plot(ax, x_data, y_data, line_data, color, label, mean_value)

Creates a scatter plot with a fitted line for a given quantile.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def create_plot(self, ax, x_data, y_data, line_data, color, label, mean_value):
    """Draw one quantile panel on ``ax``: data scatter, model curve, y = 0 reference.

    Args:
        ax: Axes object to draw on (presumably a matplotlib ``Axes``).
        x_data: Values for the horizontal axis of the scatter.
        y_data: Values for the vertical axis of the scatter.
        line_data: Model values plotted against ``self.x_vals``.
        color: Colour shared by the scatter and the model line.
        label: Legend label for the scatter points.
        mean_value: Interactor mean shown in the model line's legend label.
    """
    # Observed points first, then the model curve in the same colour.
    ax.scatter(x_data, y_data, color=color, alpha=0.5, label=label)
    model_label = f"Model (Mean = {mean_value:.2f})"
    thickness = self.line_thickness
    ax.plot(self.x_vals, line_data, color=color, label=model_label, linewidth=thickness)

    # Dashed horizontal reference line at zero response.
    ax.axhline(0, color="black", linestyle="--", linewidth=thickness)

    ax.tick_params(axis="both", which="major", labelsize=self.axis_tick_font_size)
    # The legend is built after both labelled artists have been added.
    ax.legend(fontsize=self.legend_font_size, loc="upper left", frameon=False)

    label_size = self.font_size
    ax.set_xlabel("Perturbed TF binding strength", fontsize=label_size)
    ax.set_ylabel("Perturbed TF perturbation response", fontsize=label_size)

plot(**kwargs)

Main plotting method that returns the `plt` object for further customization.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def plot(self, **kwargs):
    """Build the two-panel figure and return the `plt` object for customization.

    The left panel shows the bottom-quantile group and the right panel the
    top-quantile group. Column 0 of ``df_bottom`` / ``df_top`` is jittered in
    place before plotting.

    Raises:
        ValueError: For the sigmoid model, if ``left_asymptote``,
            ``right_asymptote`` or ``B`` is passed here instead of to the
            constructor.
    """
    if self.model_type == "sigmoid":
        for k in ["left_asymptote", "right_asymptote", "B"]:
            if k in kwargs:
                # The check fires when the value IS passed here. The message
                # says where it belongs rather than claiming it is missing.
                raise ValueError(
                    f"{k} must be provided to the InteractorDiagnosticPlot "
                    "constructor, not to plot(), for the sigmoid model."
                )
    # Compute ECDF and split data
    self.compute_ecdf()

    # Calculate means of interactor (X2) for both quantiles
    self.calculate_means()

    # Create model lines for the plots
    self.create_model_lines(**kwargs)

    # Add jitter to Y values
    self.df_bottom.iloc[:, 0] = self.add_jitter(self.df_bottom.iloc[:, 0])
    self.df_top.iloc[:, 0] = self.add_jitter(self.df_top.iloc[:, 0])

    # Label the groups with the configured quantile (e.g. 0.1 -> "10%")
    # instead of a hard-coded percentage.
    pct = f"{self.quantile * 100:g}%"

    _fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot bottom quantile
    self.create_plot(
        ax1,
        self.df_bottom.iloc[:, 1],
        self.df_bottom.iloc[:, 0],
        self.bottom_line,
        "blue",
        f"Bottom {pct} by Met31 binding",
        self.bottom_x2_mean,
    )

    # Plot top quantile
    self.create_plot(
        ax2,
        self.df_top.iloc[:, 1],
        self.df_top.iloc[:, 0],
        self.top_line,
        "red",
        f"Top {pct} by Met31 binding",
        self.top_x2_mean,
    )

    # Return plt object for further customization
    return plt

validate_inputs()

Validates the inputs for the class.

Source code in yeastdnnexplorer/utils/InteractorDiagnosticPlot.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def validate_inputs(self):
    """Validate the constructor arguments.

    Raises:
        ValueError: If ``df`` is not a three-column DataFrame, ``quantile`` is
            not strictly between 0 and 1, ``B`` does not hold exactly three real
            numbers, ``model_type`` is unknown, or a sigmoid model is requested
            without both asymptotes.
    """
    if not isinstance(self.df, pd.DataFrame) or self.df.shape[1] != 3:
        raise ValueError("`df` must be a DataFrame with exactly three columns.")
    if not (0 < self.quantile < 1):
        raise ValueError("Quantile must be between 0 and 1.")
    try:
        n_coefficients = len(self.B)
    except TypeError:
        # Scalars and other unsized objects cannot hold three coefficients.
        n_coefficients = -1
    # Accept Python and NumPy ints/floats (e.g. elements of a float32 array),
    # but reject bools, which are technically ints.
    if n_coefficients != 3 or not all(
        isinstance(b, (int, float, np.integer, np.floating))
        and not isinstance(b, (bool, np.bool_))
        for b in self.B
    ):
        raise ValueError(
            "B must contain exactly three real-valued (int or float) coefficients."
        )
    if self.model_type not in ["linear", "sigmoid"]:
        raise ValueError("model_type must be either 'linear' or 'sigmoid'.")
    if self.model_type == "sigmoid":
        if self.left_asymptote is None or self.right_asymptote is None:
            raise ValueError(
                "You must provide left_asymptote and "
                "right_asymptote for sigmoid model."
            )