Source code for luna.analysis.residues

import math
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from luna.mol.entry import MolFileEntry


def _format_compounds(compounds):
    """Return a ``;``-joined label for a sequence of residue objects.

    Each residue is rendered as ``<parent id>/<resname>/<id[1]><id[2]>``
    (presumably chain id, residue name, residue number and insertion code
    -- confirm against the residue class used by the project).
    """
    return ";".join(
        "%s/%s/%d%s" % (r.parent.id, r.resname, r.id[1], r.id[2].strip())
        for r in compounds
    )


def generate_residue_matrix(interactions_mngrs, by_interaction=True):
    """Generate a matrix to count interactions per residue.

    Parameters
    ----------
    interactions_mngrs : iterable of :class:`~luna.interaction.calc.InteractionsManager`
        A sequence of :class:`~luna.interaction.calc.InteractionsManager`
        objects from where interactions will be recovered.
    by_interaction : bool
        If True (the default), count the number of each interaction type
        per residue. Otherwise, count the overall number of interactions
        per residue.

    Returns
    -------
     : :class:`pandas.DataFrame`
        Rows are indexed by ``entry`` (and ``interaction`` when
        ``by_interaction`` is True) and columns are residue labels.
        When no interaction qualifies (e.g., ``interactions_mngrs`` is
        empty), an empty data frame with the same index/column names is
        returned instead of raising an error.
    """
    data_by_entry = defaultdict(lambda: defaultdict(int))
    residues = set()

    for inter_mngr in interactions_mngrs:
        entry = inter_mngr.entry

        for inter in inter_mngr:
            src_grp = inter.src_grp
            trgt_grp = inter.trgt_grp

            # Continue if no target is in the interaction.
            if not src_grp.has_target() and not trgt_grp.has_target():
                continue
            # Ignore interactions involving the same compounds.
            if src_grp.compounds == trgt_grp.compounds:
                continue

            # Only the partner side of the interaction is counted: the group
            # opposite to the one containing hetero atoms. When neither
            # group has hetero atoms, the two sides are ordered and the
            # greater one is kept so the choice does not depend on which
            # group happens to be the source.
            if src_grp.has_hetatm():
                partner = sorted(trgt_grp.compounds)
            elif trgt_grp.has_hetatm():
                partner = sorted(src_grp.compounds)
            else:
                _, partner = sorted([sorted(src_grp.compounds),
                                     sorted(trgt_grp.compounds)])

            res_label = _format_compounds(partner)

            if isinstance(entry, MolFileEntry):
                entry_id = entry.mol_id
            else:
                entry_id = entry.to_string()

            key = (entry_id, inter.type) if by_interaction else entry_id

            data_by_entry[key][res_label] += 1
            residues.add(res_label)

    # Nothing was counted: pivoting a column-less frame would raise a
    # KeyError, so return an empty matrix that carries the expected labels.
    if not data_by_entry:
        if by_interaction:
            index = pd.MultiIndex.from_arrays(
                [[], []], names=["entry", "interaction"])
        else:
            index = pd.Index([], name="entry")
        return pd.DataFrame(index=index,
                            columns=pd.Index([], name="residues"))

    rows = []
    if by_interaction:
        entries = {k[0] for k in data_by_entry}
        interactions = {k[1] for k in data_by_entry}

        # Emit every entry/interaction/residue combination so that all
        # entries share the same set of interaction rows (zero-filled).
        for e in entries:
            for i in interactions:
                counts = data_by_entry.get((e, i), {})
                for res in residues:
                    rows.append((e, i, res, counts.get(res, 0)))

        df = pd.DataFrame(
            rows, columns=["entry", "interaction", "residues", "frequency"])
        return pd.pivot_table(df, index=['entry', 'interaction'],
                              columns='residues', values='frequency',
                              fill_value=0)

    for key, counts in data_by_entry.items():
        for res, freq in counts.items():
            rows.append((key, res, freq))

    df = pd.DataFrame(rows, columns=["entry", "residues", "frequency"])
    return pd.pivot_table(df, index='entry', columns='residues',
                          values='frequency', fill_value=0)
def heatmap(data_df, figsize=None, cmap="Blues", heatmap_kw=None, gridspec_kw=None):
    """Plot a residue matrix as a color-encoded matrix.

    Parameters
    ----------
    data_df : :class:`pandas.DataFrame`
        A residue matrix produced with
        :func:`~luna.analysis.residues.generate_residue_matrix`.
    figsize : tuple, optional
        Size (width, height) of a figure in inches.
    cmap : str, iterable of str
        The mapping from data values to color space.
        The default value is 'Blues'.
    heatmap_kw : dict, optional
        Keyword arguments for :func:`seaborn.heatmap`.
    gridspec_kw : dict, optional
        Keyword arguments for :class:`matplotlib.gridspec.GridSpec`.
        Used only if the residue matrix (``data_df``) contains
        interactions. An ``ncols`` item, if present, sets the number of
        subplot columns (default: 3). The dict passed by the caller is
        not modified.

    Returns
    -------
     : :class:`matplotlib.axes.Axes` or :class:`numpy.ndarray` of :class:`matplotlib.axes.Axes`
        A single axes object when ``data_df`` has no ``interaction``
        level; otherwise a 2-D array (rows x columns) of axes, one per
        interaction type, with unused trailing axes switched off.
    """
    data_df = data_df.reset_index()

    heatmap_kw = heatmap_kw or {}
    # Copy before popping "ncols" below so the caller's dict stays intact
    # and can be reused across calls.
    gridspec_kw = dict(gridspec_kw or {})

    # The leading columns after reset_index() are assumed to be the former
    # index levels ("entry" and, optionally, "interaction"); the rest are
    # residue counts used to set a color scale shared by all panels.
    interactions = None
    if "interaction" in data_df.columns:
        interactions = sorted(data_df["interaction"].unique())
        max_value = data_df[data_df.columns[2:]].max().max()
    else:
        max_value = data_df[data_df.columns[1:]].max().max()

    if not interactions:
        data_df.set_index('entry', inplace=True)

        # Open a new figure so that figsize is honored by the heatmap.
        plt.figure(figsize=figsize)
        ax = sns.heatmap(data_df, cmap=cmap, vmax=max_value, vmin=0,
                         **heatmap_kw)
        ax.set_xlabel("")
        ax.set_ylabel("")
        return ax

    ncols = gridspec_kw.pop("ncols", 3)
    nrows = math.ceil(len(interactions) / ncols)

    # squeeze=False keeps the result 2-D even for a single row or column,
    # so indexing and the return shape do not depend on the grid size.
    _, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False,
                          gridspec_kw=gridspec_kw)

    # Row-major walk: fill each row left to right before moving down.
    flat_axes = axs.ravel()
    for ax, interaction in zip(flat_axes, interactions):
        df = data_df[data_df["interaction"] == interaction].copy()
        df.drop(columns="interaction", inplace=True)
        df.set_index('entry', inplace=True)

        g = sns.heatmap(df, cmap=cmap, vmax=max_value, vmin=0, ax=ax,
                        **heatmap_kw)
        g.set_title(interaction)
        g.set_xlabel("")
        g.set_ylabel("")

    # Hide the unused axes left over at the end of the last row.
    for ax in flat_axes[len(interactions):]:
        ax.axis('off')

    return axs