Fine-Mapping API

Pipeline and Fine-Mapping

Main module.

fine_map(locus_set, tool='susie', max_causal=5, adaptive_max_causal=False, set_L_by_cojo=True, p_cutoff=5e-08, collinear_cutoff=0.9, window_size=10000000, maf_cutoff=0.01, diff_freq_cutoff=0.2, combine_cred='union', combine_pip='max', jaccard_threshold=0.1, timeout_minutes=None, strategy=None, significant_threshold=5e-08, **kwargs)

Perform fine-mapping on a locus set.

Parameters:

Name Type Description Default
locus_set LocusSet

Locus set to fine-map.

required
tool str

Fine-mapping tool. Choose from ["abf", "abf_cojo", "carma", "finemap", "rsparsepro", "susie", "susie_ash", "susie_inf", "multisusie", "susiex", "mesusie"].
- Single-input tools (abf, abf_cojo, carma, finemap, rsparsepro, susie, susie_ash, susie_inf): process each locus individually.
- Multi-input tools (multisusie, susiex, mesusie): process all loci together.
When a single-input tool is run on multiple loci, the per-locus results are combined automatically using combine_cred and combine_pip.

'susie'
combine_cred str

Method to combine credible sets, by default "union". Options: "union", "intersection", "cluster".
- "union": Union of all credible sets to form a merged credible set.
- "intersection": First merge the credible sets from the same tool, then take the intersection of all merged credible sets. No credible set is returned if no common SNPs are found.
- "cluster": Merge credible sets whose Jaccard index exceeds jaccard_threshold (0.1 by default).

'union'
combine_pip str

Method to combine PIPs, by default "max". Options: "max", "min", "mean", "meta".
- "meta": PIP_meta = 1 - prod(1 - PIP_i), where i indexes the tools and PIP_i = 0 when the SNP is not in the credible set of tool i.
- "max": Maximum PIP value for each SNP across all tools.
- "min": Minimum PIP value for each SNP across all tools.
- "mean": Mean PIP value for each SNP across all tools.

'max'
jaccard_threshold float

Jaccard index threshold for the "cluster" method, by default 0.1.

0.1
timeout_minutes Optional[float]

Maximum runtime per locus in minutes when running the FINEMAP tool. Must be positive; a non-positive value raises ValueError. If None, FINEMAP uses 30 minutes. Ignored for other tools.

None
max_causal int

Maximum number of causal variants, by default 5.

5
adaptive_max_causal bool

Enable adaptive tuning of max_causal, by default False. When True, max_causal is adjusted automatically based on results:
- If the number of credible sets >= max_causal, increase it by 5 (up to 20).
- If convergence fails, decrease it by 1 (down to 1).
Applies to: finemap, susie, susie_ash, susie_inf, rsparsepro (per-locus); multisusie, susiex (LocusSet-level).

False
strategy str

DEPRECATED. This parameter is no longer used and will be removed in a future version. The strategy is now automatically determined based on the tool and data structure.

None
significant_threshold float

P-value threshold for a variant to be considered significant. If no variant passes this threshold, single-input tools return empty credible sets with zero posterior probabilities. Defaults to 5e-8.

5e-08
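
The "cluster" criterion and the combine_pip rules in the table above can be illustrated with a minimal sketch. The helper names `jaccard_index` and `combine_pips_sketch` are hypothetical and are not part of credtools. PIPs are represented as plain dicts keyed by SNP ID. The table states that a missing SNP counts as PIP 0 for "meta"; applying the same rule to "max", "min", and "mean" is an assumption of this sketch.

```python
from math import prod
from typing import Dict, List, Set


def jaccard_index(a: Set[str], b: Set[str]) -> float:
    """Jaccard index: shared SNPs divided by distinct SNPs (0.0 if both sets are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def combine_pips_sketch(
    pips_per_tool: List[Dict[str, float]], method: str = "max"
) -> Dict[str, float]:
    """Combine per-SNP PIPs across results; a missing SNP counts as PIP 0 (assumption)."""
    snps = sorted(set().union(*pips_per_tool))
    combined: Dict[str, float] = {}
    for snp in snps:
        values = [p.get(snp, 0.0) for p in pips_per_tool]
        if method == "max":
            combined[snp] = max(values)
        elif method == "min":
            combined[snp] = min(values)
        elif method == "mean":
            combined[snp] = sum(values) / len(values)
        elif method == "meta":
            # PIP_meta = 1 - prod(1 - PIP_i)
            combined[snp] = 1.0 - prod(1.0 - v for v in values)
        else:
            raise ValueError(f"Unknown method: {method}")
    return combined


cs_a = {"rs1", "rs2", "rs3"}
cs_b = {"rs3", "rs4"}
# One shared SNP out of four distinct SNPs gives 0.25, above the default threshold of 0.1.
should_merge = jaccard_index(cs_a, cs_b) > 0.1

pips = [{"rs1": 0.5, "rs2": 0.2}, {"rs1": 0.4}]
meta = combine_pips_sketch(pips, method="meta")
```

In this toy input, "meta" yields a higher PIP for rs1 than either result alone, while "min" drops rs2 to 0 because one result does not report it.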
Source code in credtools/credtools.py
def fine_map(
    locus_set: LocusSet,
    tool: str = "susie",
    max_causal: int = 5,
    adaptive_max_causal: bool = False,
    set_L_by_cojo: bool = True,
    p_cutoff: float = 5e-8,
    collinear_cutoff: float = 0.9,
    window_size: int = 10000000,
    maf_cutoff: float = 0.01,
    diff_freq_cutoff: float = 0.2,
    combine_cred: str = "union",
    combine_pip: str = "max",
    jaccard_threshold: float = 0.1,
    timeout_minutes: Optional[float] = None,
    strategy: Optional[str] = None,  # Deprecated parameter
    significant_threshold: float = 5e-8,
    **kwargs,
) -> CredibleSet:
    """
    Perform fine-mapping on a locus set.

    Parameters
    ----------
    locus_set : LocusSet
        Locus set to fine-map.
    tool : str
        Fine-mapping tool. Choose from ["abf", "abf_cojo", "finemap", "rsparsepro", "susie", "multisusie", "susiex", "mesusie"]
        - Single-input tools (abf, abf_cojo, finemap, rsparsepro, susie): Process each locus individually
        - Multi-input tools (multisusie, susiex, mesusie): Process all loci together
        When using single-input tools with multiple loci, results are automatically combined
    combine_cred : str, optional
        Method to combine credible sets, by default "union".
        Options: "union", "intersection", "cluster".
        "union":        Union of all credible sets to form a merged credible set.
        "intersection": First merge the credible sets from the same tool,
                        then take the intersection of all merged credible sets.
                        No credible set is returned if no common SNPs are found.
        "cluster":      Merge credible sets whose Jaccard index exceeds jaccard_threshold.
    combine_pip : str, optional
        Method to combine PIPs, by default "max".
        Options: "max", "min", "mean", "meta".
        "meta": PIP_meta = 1 - prod(1 - PIP_i), where i is the index of tools,
                PIP_i = 0 when the SNP is not in the credible set of the tool.
        "max":  Maximum PIP value for each SNP across all tools.
        "min":  Minimum PIP value for each SNP across all tools.
        "mean": Mean PIP value for each SNP across all tools.
    jaccard_threshold : float, optional
        Jaccard index threshold for the "cluster" method, by default 0.1.
    timeout_minutes : Optional[float], optional
        Maximum runtime per locus in minutes when running the FINEMAP tool. Defaults to 30 minutes for FINEMAP.
        Ignored for other tools.
    max_causal : int, optional
        Maximum number of causal variants, by default 5.
    adaptive_max_causal : bool, optional
        Enable adaptive max_causal parameter tuning, by default False.
        When True, automatically adjusts max_causal based on results:
        - If credible sets >= max_causal, increase by 5 (up to 20)
        - If convergence fails, decrease by 1 (down to 1)
        Applies to: finemap, susie, rsparsepro (per-locus), multisusie, susiex (LocusSet-level).
    strategy : str, optional
        DEPRECATED. This parameter is no longer used and will be removed in a future version.
        The strategy is now automatically determined based on the tool and data structure.
    significant_threshold : float, optional
        P-value threshold for a variant to be considered significant. If no variant
        passes this threshold, single-input tools return empty credible sets with zero
        posterior probabilities. Defaults to 5e-8.
    """
    # Deprecation warning for strategy parameter
    if strategy is not None:
        import warnings

        warnings.warn(
            "The 'strategy' parameter is deprecated and will be removed in a future version. "
            "The strategy is now automatically determined based on the tool and data structure.",
            DeprecationWarning,
            stacklevel=2,
        )

    kwargs.setdefault("significant_threshold", significant_threshold)

    # When adaptive_max_causal is enabled, default empty_on_nonconvergence=True so
    # non-converged runs return n_cs=0 and naturally trigger Phase-2 L decrement.
    # User-provided empty_on_nonconvergence overrides this default.
    if adaptive_max_causal:
        kwargs.setdefault("empty_on_nonconvergence", True)

    # Extract purity parameter for centralized filtering
    purity_threshold = kwargs.get("purity", 0.0)

    # Handle timeout defaults for FINEMAP
    if timeout_minutes is None and tool == "finemap":
        timeout_minutes = 30.0
    if timeout_minutes is not None:
        timeout_minutes = float(timeout_minutes)
        if timeout_minutes <= 0:
            raise ValueError("timeout_minutes must be a positive value.")
        kwargs["timeout_minutes"] = timeout_minutes

    # Define tool categories
    single_input_tools = [
        "abf",
        "abf_cojo",
        "carma",
        "finemap",
        "rsparsepro",
        "susie",
        "susie_ash",
        "susie_inf",
    ]
    multi_input_tools = ["multisusie", "susiex", "mesusie"]

    # Tool function mapping
    tool_func_dict: Dict[str, Callable[..., Any]] = {
        "abf": run_abf,
        "abf_cojo": run_abf_cojo,
        "carma": run_carma,
        "finemap": run_finemap,
        "rsparsepro": run_rsparsepro,
        "susie": run_susie,
        "susie_ash": run_susie_ash,
        "susie_inf": run_susie_inf,
        "multisusie": run_multisusie,
        "susiex": run_susiex,
        "mesusie": run_mesusie,
    }

    # Get tool-specific parameters
    inspect_dict = {
        "abf": set(inspect.signature(run_abf).parameters),
        "abf_cojo": set(inspect.signature(run_abf_cojo).parameters),
        "carma": set(inspect.signature(run_carma).parameters),
        "finemap": set(inspect.signature(run_finemap).parameters),
        "rsparsepro": set(inspect.signature(run_rsparsepro).parameters),
        "susie": set(inspect.signature(run_susie).parameters),
        "susie_ash": set(inspect.signature(run_susie_ash).parameters),
        "susie_inf": set(inspect.signature(run_susie_inf).parameters),
        "multisusie": set(inspect.signature(run_multisusie).parameters),
        "susiex": set(inspect.signature(run_susiex).parameters),
        "mesusie": set(inspect.signature(run_mesusie).parameters),
    }
    params_dict = {}
    for t, args in inspect_dict.items():
        params_dict[t] = {k: v for k, v in kwargs.items() if k in args}

    # Automatic strategy selection based on tool type
    if tool in multi_input_tools:
        # Multi-input tools: directly process the entire LocusSet
        logger.info(f"Using multi-input tool {tool} to process {locus_set.n_loci} loci")

        # Use adaptive logic if enabled
        if adaptive_max_causal:
            combined = _adaptive_fine_map_multi(
                locus_set,
                tool,
                max_causal,
                tool_func_dict[tool],
                params_dict[tool],
                purity_threshold=purity_threshold,
            )
        else:
            combined = tool_func_dict[tool](
                locus_set, max_causal=max_causal, **params_dict[tool]
            )

        # Apply purity filtering if requested
        if purity_threshold > 0:
            combined = filter_credset_by_purity(combined, min_purity=purity_threshold)

        combined.set_per_locus_results({})
        return combined

    elif tool in single_input_tools:
        if locus_set.n_loci == 1:
            # Single locus: direct analysis
            logger.info(f"Using single-input tool {tool} for single locus")
            locus = locus_set.loci[0]

            # COJO analysis for max_causal if enabled (skip for abf_cojo as it handles its own)
            if set_L_by_cojo and tool != "abf_cojo":
                max_causal = _determine_max_causal_by_cojo(
                    locus,
                    p_cutoff,
                    collinear_cutoff,
                    window_size,
                    maf_cutoff,
                    diff_freq_cutoff,
                )

            # Use adaptive logic if enabled
            if adaptive_max_causal and tool in [
                "finemap",
                "susie",
                "susie_ash",
                "susie_inf",
                "rsparsepro",
            ]:
                result = _adaptive_fine_map(
                    locus,
                    tool,
                    max_causal,
                    tool_func_dict[tool],
                    params_dict[tool],
                    purity_threshold=purity_threshold,
                )
            else:
                result = tool_func_dict[tool](
                    locus, max_causal=max_causal, **params_dict[tool]
                )

            # Apply purity filtering if requested
            if purity_threshold > 0:
                result = filter_credset_by_purity(result, min_purity=purity_threshold)

            # Set per-locus results and return
            locus_id = getattr(locus, "locus_id", getattr(locus, "name", "locus_0"))
            if hasattr(result, "copy"):
                result_copy = result.copy()
            else:
                result_copy = deepcopy(result)
            per_locus_results = {locus_id: result_copy}
            if hasattr(result, "set_per_locus_results"):
                result.set_per_locus_results(per_locus_results)
            return result

        else:
            # Multiple loci: analyze each and combine results
            logger.info(
                f"Using single-input tool {tool} for {locus_set.n_loci} loci, "
                f"will combine results using combine_cred={combine_cred}, combine_pip={combine_pip}"
            )
            all_creds = []
            for i, locus in enumerate(locus_set.loci):
                logger.info(f"Processing locus {i+1}/{locus_set.n_loci}")

                # Optionally apply COJO for each locus
                locus_max_causal = max_causal
                if set_L_by_cojo and tool != "abf_cojo":
                    locus_max_causal = _determine_max_causal_by_cojo(
                        locus,
                        p_cutoff,
                        collinear_cutoff,
                        window_size,
                        maf_cutoff,
                        diff_freq_cutoff,
                        locus_index=i + 1,
                    )

                # Run fine-mapping for this locus
                if adaptive_max_causal and tool in [
                    "finemap",
                    "susie",
                    "susie_ash",
                    "susie_inf",
                    "rsparsepro",
                ]:
                    creds = _adaptive_fine_map(
                        locus,
                        tool,
                        locus_max_causal,
                        tool_func_dict[tool],
                        params_dict[tool],
                        purity_threshold=purity_threshold,
                    )
                else:
                    creds = tool_func_dict[tool](
                        locus, max_causal=locus_max_causal, **params_dict[tool]
                    )
                all_creds.append(creds)

            # Combine results
            logger.info("Combining credible sets from all loci")
            # Collect LD matrices for purity calculation
            ld_list = [locus.ld for locus in locus_set.loci if locus.ld is not None]
            combined = combine_creds(
                all_creds,
                combine_cred=combine_cred,
                combine_pip=combine_pip,
                jaccard_threshold=jaccard_threshold,
                ld_matrices=ld_list,
                min_purity=purity_threshold,
            )
            per_locus_results: Dict[str, CredibleSet] = {}
            for locus, cred in zip(locus_set.loci, all_creds):
                cred_copy = cred.copy()
                if purity_threshold > 0:
                    cred_copy = filter_credset_by_purity(
                        cred_copy, min_purity=purity_threshold
                    )
                per_locus_results[locus.locus_id] = cred_copy
            combined.set_per_locus_results(per_locus_results)
            return combined

    else:
        raise ValueError(
            f"Tool {tool} not recognized. Available tools: {list(tool_func_dict.keys())}"
        )
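
The adaptive max_causal behavior described in the parameter table (increase by 5, up to 20, while the number of credible sets reaches max_causal; decrease by 1, down to 1, when convergence fails) can be sketched as a two-phase loop. `_adaptive_fine_map` itself is not shown on this page, so the control flow below is an assumption rather than the library's implementation. `run_tool` and `fake_tool` are hypothetical stand-ins that return `(n_cs, converged)`.

```python
from typing import Callable, Tuple


def adaptive_max_causal_sketch(
    run_tool: Callable[[int], Tuple[int, bool]],
    max_causal: int = 5,
    step: int = 5,
    upper: int = 20,
    lower: int = 1,
) -> Tuple[int, int, bool]:
    """Return (final max_causal, n_cs, converged) after adaptive tuning."""
    L = max_causal
    n_cs, converged = run_tool(L)

    # Phase 1: the run converged and used every available slot, so allow more.
    while converged and n_cs >= L and L < upper:
        L = min(L + step, upper)
        n_cs, converged = run_tool(L)

    # Phase 2: the run did not converge, so back off one step at a time.
    while not converged and L > lower:
        L -= 1
        n_cs, converged = run_tool(L)

    return L, n_cs, converged


def fake_tool(L: int) -> Tuple[int, bool]:
    """Toy tool: 7 true signals, converges only when L <= 8, empty result otherwise."""
    converged = L <= 8
    n_cs = min(L, 7) if converged else 0
    return n_cs, converged


final_L, final_n_cs, final_converged = adaptive_max_causal_sketch(fake_tool)
```

Keeping the two phases separate means the sketch never increases max_causal again after it has started decreasing, which avoids oscillating between a saturated value and a non-converging one.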

pipeline(loci_df, meta_method='meta_all', skip_qc=False, tool='susie', outdir='.', calculate_lambda_s=False, strategy=None, **kwargs)

Run the whole fine-mapping pipeline on a list of loci.

Parameters:

Name Type Description Default
loci_df DataFrame

Dataframe containing the locus information.

required
meta_method str

Meta-analysis method, by default "meta_all". Options: "meta_all", "meta_by_population", "no_meta".

'meta_all'
skip_qc bool

Skip QC, by default False.

False
tool str

Fine-mapping tool, by default "susie".

'susie'
outdir str

Output directory for all result files, by default ".". Created if it does not exist.

'.'
calculate_lambda_s bool

Whether to calculate lambda_s parameter using estimate_s_rss function, by default False.

False
strategy str

DEPRECATED. This parameter is no longer used and will be removed in a future version.

None
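
`pipeline` passes its extra keyword arguments on to `fine_map`, and the `fine_map` source listing forwards to the selected tool only those keyword arguments whose names appear in that tool's signature. A minimal sketch of that filtering pattern follows. `run_tool_a`, `run_tool_b`, and `filter_kwargs` are hypothetical stand-ins, not credtools functions.

```python
import inspect
from typing import Any, Callable, Dict


def run_tool_a(locus: str, max_causal: int = 1, coverage: float = 0.95) -> Dict[str, Any]:
    """Hypothetical tool that accepts a `coverage` option."""
    return {"tool": "a", "max_causal": max_causal, "coverage": coverage}


def run_tool_b(locus: str, max_causal: int = 1, n_iter: int = 100) -> Dict[str, Any]:
    """Hypothetical tool that accepts an `n_iter` option."""
    return {"tool": "b", "max_causal": max_causal, "n_iter": n_iter}


def filter_kwargs(func: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that `func` names in its signature."""
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}


user_kwargs = {"coverage": 0.9, "n_iter": 50, "typo_option": True}
result_a = run_tool_a("locus_1", max_causal=5, **filter_kwargs(run_tool_a, user_kwargs))
result_b = run_tool_b("locus_1", max_causal=5, **filter_kwargs(run_tool_b, user_kwargs))
```

With this pattern a misspelled or unsupported option is dropped without an error, so it is worth double-checking option names against the tool being run.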
Source code in credtools/credtools.py
def pipeline(
    loci_df: pd.DataFrame,
    meta_method: str = "meta_all",
    skip_qc: bool = False,
    tool: str = "susie",
    outdir: str = ".",
    calculate_lambda_s: bool = False,
    strategy: Optional[str] = None,  # Deprecated parameter
    **kwargs,
):
    """
    Run the whole fine-mapping pipeline on a list of loci.

    Parameters
    ----------
    loci_df : pd.DataFrame
        Dataframe containing the locus information.
    meta_method : str, optional
        Meta-analysis method, by default "meta_all".
        Options: "meta_all", "meta_by_population", "no_meta".
    skip_qc : bool, optional
        Skip QC, by default False.
    tool : str, optional
        Fine-mapping tool, by default "susie".
    calculate_lambda_s : bool, optional
        Whether to calculate lambda_s parameter using estimate_s_rss function, by default False.
    strategy : str, optional
        DEPRECATED. This parameter is no longer used and will be removed in a future version.
    """
    import sys
    from datetime import datetime

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # Initialize run summary
    run_summary = {
        "start_time": datetime.now().isoformat(),
        "total_loci": 0,
        "successful_loci": 0,
        "failed_loci": 0,
        "errors": [],
        "tool": tool,
        "meta_method": meta_method,
        "parameters": kwargs,
    }

    # Collect all credible sets for summary
    all_credible_sets = []

    try:
        locus_set = load_locus_set(loci_df, calculate_lambda_s=calculate_lambda_s)
        run_summary["total_loci"] = locus_set.n_loci

        # Compute heterogeneity BEFORE meta combines data
        if meta_method == "meta_by_population":
            het_metrics = compute_heterogeneity_by_population(locus_set)
        else:
            het_metrics = compute_heterogeneity(locus_set)
        het_summary = heterogeneity_summary(het_metrics, locus_set)
        save_heterogeneity(het_metrics, outdir, summary=het_summary)
        logger.info("Heterogeneity metrics computed and saved.")

        # meta-analysis
        locus_set = meta(locus_set, meta_method=meta_method)
        logger.info(f"Meta-analysis complete, {locus_set.n_loci} loci loaded.")
        logger.info(f"Save meta-analysis results to {outdir}.")

        for locus in locus_set.loci:
            out_prefix = f"{outdir}/{locus.prefix}"
            locus.sumstats.to_csv(f"{out_prefix}.sumstat", sep="\t", index=False)
            np.savez_compressed(
                f"{out_prefix}.ld.npz", ld=locus.ld.r.astype(np.float16)
            )
            locus.ld.map.to_csv(f"{out_prefix}.ldmap", sep="\t", index=False)

        # QC
        if not skip_qc:
            qc_metrics = locus_qc(locus_set)
            logger.info(f"QC complete, {qc_metrics.keys()} metrics saved.")
            for k, v in qc_metrics.items():
                v.to_csv(
                    f"{outdir}/{k}.txt", sep="\t", index=False, float_format="%.6f"
                )

        # fine-mapping
        try:
            creds = fine_map(locus_set, tool=tool, strategy=strategy, **kwargs)
            run_summary["successful_loci"] = locus_set.n_loci

            # Create enhanced PIPs DataFrame
            enhanced_pips = creds.create_enhanced_pips_df(locus_set)

            # Extract causal variants and create CS summary BEFORE formatting
            # (formatting converts PIP to strings which breaks numeric comparisons)
            locus_id = f"{locus_set.chrom}_{locus_set.start}_{locus_set.end}"
            causal_variants = enhanced_pips[enhanced_pips["CRED"] != 0].copy()
            if len(causal_variants) > 0:
                causal_variants["locus_id"] = locus_id
                all_credible_sets.append(causal_variants)

            # Create credible sets summary (one row per CS)
            from credtools.credibleset import generate_cs_summary

            cs_summary_list = generate_cs_summary(causal_variants, locus_id, locus_set)

            # Format enhanced PIPs for output
            from credtools.utils import format_enhanced_pips

            enhanced_pips = format_enhanced_pips(enhanced_pips)

            # Save formatted enhanced PIPs
            output_file = f"{outdir}/pips.txt.gz"
            enhanced_pips.to_csv(
                output_file,
                sep="\t",
                index=False,
                compression="gzip",
            )

            # Save causal variants
            if len(causal_variants) > 0:
                causal_variants.to_csv(
                    f"{outdir}/causal_variants.txt.gz",
                    sep="\t",
                    index=False,
                    compression="gzip",
                )

            if cs_summary_list:
                cs_summary_df = pd.DataFrame(cs_summary_list)
                cs_summary_df.to_csv(
                    f"{outdir}/credible_sets_summary.txt.gz",
                    sep="\t",
                    index=False,
                    compression="gzip",
                )

            # Save parameters (without lead_snps, snps, cs_sizes)
            parameters_dict = {
                "tool": creds.tool,
                "n_cs": creds.n_cs,
                "coverage": creds.coverage,
                "parameters": creds.parameters,
            }
            with open(f"{outdir}/parameters.json", "w") as f:
                json.dump(parameters_dict, f, indent=4)

            logger.info(f"Fine-mapping complete, {creds.n_cs} credible sets saved.")

        except Exception as e:
            error_msg = f"Fine-mapping failed: {str(e)}"
            logger.error(error_msg)
            print(f"ERROR: {error_msg}", file=sys.stderr)
            run_summary["failed_loci"] = locus_set.n_loci
            run_summary["errors"].append(error_msg)

    except Exception as e:
        error_msg = f"Pipeline failed: {str(e)}"
        logger.error(error_msg)
        print(f"ERROR: {error_msg}", file=sys.stderr)
        run_summary["errors"].append(error_msg)

    finally:
        # Generate run summary
        run_summary["end_time"] = datetime.now().isoformat()
        _generate_run_summary(run_summary, f"{outdir}/run_summary.log")

    return
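
Because `pipeline` returns nothing, results have to be read back from `outdir`. The sketch below loads the files that the listing above writes (`pips.txt.gz`, `causal_variants.txt.gz`, `credible_sets_summary.txt.gz`, `parameters.json`) using only the standard library. The helper names `read_tsv_gz` and `load_pipeline_outputs` are hypothetical, and all table values come back as strings.

```python
import csv
import gzip
import json
import os
from typing import Any, Dict, List


def read_tsv_gz(path: str) -> List[Dict[str, str]]:
    """Read a gzip-compressed, tab-separated file into a list of row dicts."""
    with gzip.open(path, "rt", newline="") as handle:
        return list(csv.DictReader(handle, delimiter="\t"))


def load_pipeline_outputs(outdir: str) -> Dict[str, Any]:
    """Load the pipeline outputs present in `outdir`; files that are absent are skipped."""
    outputs: Dict[str, Any] = {}
    # causal_variants and credible_sets_summary are only written when non-empty.
    for name in ("pips", "causal_variants", "credible_sets_summary"):
        path = os.path.join(outdir, f"{name}.txt.gz")
        if os.path.exists(path):
            outputs[name] = read_tsv_gz(path)
    params_path = os.path.join(outdir, "parameters.json")
    if os.path.exists(params_path):
        with open(params_path) as handle:
            outputs["parameters"] = json.load(handle)
    return outputs
```

A call such as `load_pipeline_outputs(".")` returns a dict with one entry per file found. With pandas installed, `pd.read_csv(path, sep="\t")` reads the same gzip-compressed tables into DataFrames.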

Credible Sets

Credible Set functions.

CredibleSet(tool, parameters, coverage, n_cs, cs_sizes, lead_snps, snps, pips, per_locus_results=None, purity=None, converged=None, n_iter=None)

Class representing credible sets from one fine-mapping tool.

Attributes:

Name Type Description
tool str

The name of the fine-mapping tool.

n_cs int

The number of credible sets.

coverage float

The coverage of the credible sets.

lead_snps List[str]

List of lead SNPs.

snps List[List[str]]

List of SNPs for each credible set.

cs_sizes List[int]

Sizes of each credible set.

pips Series

Posterior inclusion probabilities.

parameters Dict[str, Any]

Additional parameters used by the fine-mapping tool.

Parameters:

Name Type Description Default
tool str

The name of the fine-mapping tool.

required
parameters Dict[str, Any]

Additional parameters used by the fine-mapping tool.

required
coverage float

The coverage of the credible sets.

required
n_cs int

The number of credible sets.

required
cs_sizes List[int]

Sizes of each credible set.

required
lead_snps List[str]

List of lead SNPs.

required
snps List[List[str]]

List of SNPs for each credible set.

required
pips Series

Posterior inclusion probabilities.

required
per_locus_results Optional[Dict[str, CredibleSet]]

Mapping of locus identifiers to their individual credible set results.

None
purity Optional[List[Optional[float]]]

List of purity values, one per credible set. Purity is the minimum absolute LD correlation (|r|) over all SNP pairs in a credible set. None if the LD matrix is not available.

None
converged Optional[bool]

Whether the underlying iterative algorithm converged. None when the producing tool is non-iterative or convergence is not tracked (e.g., ABF, FINEMAP).

None
n_iter Optional[int]

Number of iterations performed by the underlying algorithm. None when not tracked.

None
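
The purity definition above (minimum absolute LD correlation over all SNP pairs in a credible set) can be written out as a short sketch. `purity_sketch` is a hypothetical helper, not the credtools implementation. The LD matrix is a plain nested list, and returning 1.0 for a single-SNP set is a convention chosen for this sketch.

```python
from itertools import combinations
from typing import List, Optional, Sequence


def purity_sketch(ld_r: Sequence[Sequence[float]], cs_indices: List[int]) -> Optional[float]:
    """Minimum |r| over all SNP pairs in one credible set."""
    if not cs_indices:
        return None
    if len(cs_indices) == 1:
        # Assumption of this sketch: a single-SNP set has no pairs, so report 1.0.
        return 1.0
    return min(abs(ld_r[i][j]) for i, j in combinations(cs_indices, 2))


# Toy 3-SNP LD correlation matrix (symmetric, unit diagonal).
ld = [
    [1.0, 0.9, -0.2],
    [0.9, 1.0, 0.6],
    [-0.2, 0.6, 1.0],
]
tight_set = purity_sketch(ld, [0, 1])
loose_set = purity_sketch(ld, [0, 1, 2])
```

Here the two-SNP set has purity 0.9, while adding the third SNP lowers it to 0.2 because of the weakly correlated pair.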
Source code in credtools/credibleset.py
def __init__(
    self,
    tool: str,
    parameters: Dict[str, Any],
    coverage: float,
    n_cs: int,
    cs_sizes: List[int],
    lead_snps: List[str],
    snps: List[List[str]],
    pips: pd.Series,
    per_locus_results: Optional[Dict[str, "CredibleSet"]] = None,
    purity: Optional[List[Optional[float]]] = None,
    converged: Optional[bool] = None,
    n_iter: Optional[int] = None,
) -> None:
    """
    Initialize CredibleSet object.

    Parameters
    ----------
    tool : str
        The name of the fine-mapping tool.
    parameters : Dict[str, Any]
        Additional parameters used by the fine-mapping tool.
    coverage : float
        The coverage of the credible sets.
    n_cs : int
        The number of credible sets.
    cs_sizes : List[int]
        Sizes of each credible set.
    lead_snps : List[str]
        List of lead SNPs.
    snps : List[List[str]]
        List of SNPs for each credible set.
    pips : pd.Series
        Posterior inclusion probabilities.
    per_locus_results : Optional[Dict[str, "CredibleSet"]], optional
        Mapping of locus identifiers to their individual credible set results.
    purity : Optional[List[Optional[float]]], optional
        List of purity values for each credible set. Purity is the minimum
        absolute LD R value between all SNP pairs in a credible set.
        None if LD matrix is not available.
    converged : Optional[bool], optional
        Whether the underlying iterative algorithm converged. None when
        the producing tool is non-iterative or convergence is not
        tracked (e.g., ABF, FINEMAP).
    n_iter : Optional[int], optional
        Number of iterations performed by the underlying algorithm.
        None when not tracked.
    """
    self._tool = tool
    self._parameters = parameters
    self._coverage = coverage
    self._n_cs = n_cs
    self._cs_sizes = cs_sizes
    self._lead_snps = lead_snps
    self._snps = snps
    self._pips = pips
    self._per_locus_results: Dict[str, "CredibleSet"] = per_locus_results or {}
    self._purity = purity
    self._converged = converged
    self._n_iter = n_iter

converged property

Get convergence status of the underlying iterative algorithm.

coverage property

Get the coverage.

cs_sizes property

Get the sizes of each credible set.

lead_snps property

Get the lead SNPs.

n_cs property

Get the number of credible sets.

n_iter property

Get the number of iterations performed by the underlying algorithm.

parameters property

Get the parameters.

per_locus_results property

Get per-locus credible set results.

pips property

Get the PIPs.

purity property

Get the purity values for each credible set.

snps property

Get the SNPs.

tool property

Get the tool name.

__repr__()

Return a string representation of the CredibleSet object.

Returns:

Type Description
str

String representation of the CredibleSet object.

Source code in credtools/credibleset.py
def __repr__(self) -> str:
    """
    Return a string representation of the CredibleSet object.

    Returns
    -------
    str
        String representation of the CredibleSet object.
    """
    return (
        f"CredibleSet(\n  tool={self.tool}, coverage={self.coverage}, n_cs={self.n_cs}, cs_sizes={self.cs_sizes}, lead_snps={self.lead_snps},"
        + f"\n  Parameters: {json.dumps(self.parameters)}\n)"
    )

copy()

Copy the CredibleSet object.

Returns:

Type Description
CredibleSet

A copy of the CredibleSet object.

Source code in credtools/credibleset.py
def copy(self) -> "CredibleSet":
    """
    Copy the CredibleSet object.

    Returns
    -------
    CredibleSet
        A copy of the CredibleSet object.
    """
    copied = CredibleSet(
        tool=self.tool,
        parameters=dict(self.parameters),
        coverage=self.coverage,
        n_cs=self.n_cs,
        cs_sizes=self.cs_sizes.copy(),
        lead_snps=self.lead_snps.copy(),
        snps=[list(snp) for snp in self.snps],
        pips=self.pips.copy(),
        purity=self.purity.copy() if self.purity is not None else None,
        converged=self.converged,
        n_iter=self.n_iter,
    )
    if self.per_locus_results:
        per_locus_copy = {}
        for key, value in self.per_locus_results.items():
            if value is self:
                per_locus_copy[key] = copied
            else:
                per_locus_copy[key] = value.copy()
        copied.set_per_locus_results(per_locus_copy)
    return copied

create_enhanced_pips_df(locus_set)

Create DataFrame with PIPs and full sumstats information.

Parameters:

Name Type Description Default
locus_set LocusSet

The locus set containing locus data.

required

Returns:

Type Description
DataFrame

DataFrame containing full sumstats, PIPs, R2, and credible set assignments.
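
In the source listing for this method, the R2 column is the squared LD correlation between each SNP and the lead SNP, where the lead SNP is the one with the lowest p-value. A minimal sketch of that calculation with plain lists follows. `r2_with_lead` is a hypothetical helper, and it assumes the summary statistics and the LD matrix rows are in the same SNP order.

```python
from typing import Dict, List, Sequence


def r2_with_lead(
    snp_ids: List[str], p_values: List[float], ld_r: Sequence[Sequence[float]]
) -> Dict[str, float]:
    """Squared LD correlation of every SNP with the lead SNP (lowest p-value)."""
    lead = min(range(len(p_values)), key=p_values.__getitem__)
    return {snp: ld_r[lead][j] ** 2 for j, snp in enumerate(snp_ids)}


snp_ids = ["rs1", "rs2", "rs3"]
p_values = [1e-3, 1e-9, 0.2]  # rs2 is the lead SNP
ld = [
    [1.0, 0.9, -0.2],
    [0.9, 1.0, 0.6],
    [-0.2, 0.6, 1.0],
]
r2 = r2_with_lead(snp_ids, p_values, ld)
```

The lead SNP always gets R2 = 1 because the diagonal of a correlation matrix is 1.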

Source code in credtools/credibleset.py
def create_enhanced_pips_df(self, locus_set) -> pd.DataFrame:
    """
    Create DataFrame with PIPs and full sumstats information.

    Parameters
    ----------
    locus_set : LocusSet
        The locus set containing locus data.

    Returns
    -------
    pd.DataFrame
        DataFrame containing full sumstats, PIPs, R2, and credible set assignments.
    """
    from credtools.constants import ColName
    from credtools.qc import intersect_sumstat_ld

    # Collect all unique SNPIDs from PIPs
    all_snpids = self.pips.index.tolist()

    # Initialize the result DataFrame with SNPIDs
    result_df = pd.DataFrame({ColName.SNPID: all_snpids})

    # Process based on number of loci
    if locus_set.n_loci == 1:
        # Single locus case - simpler column names
        locus = locus_set.loci[0]

        # Make sure we have matched sumstats and LD
        locus_copy = locus.copy()
        locus_copy = intersect_sumstat_ld(locus_copy)

        # Merge with sumstats
        sumstats_cols = [
            ColName.SNPID,
            ColName.CHR,
            ColName.BP,
            ColName.RSID,
            ColName.EA,
            ColName.NEA,
            ColName.EAF,
            ColName.MAF,
            ColName.BETA,
            ColName.SE,
            ColName.P,
        ]

        # Get available columns from sumstats
        available_cols = [
            col for col in sumstats_cols if col in locus_copy.sumstats.columns
        ]
        result_df = result_df.merge(
            locus_copy.sumstats[available_cols], on=ColName.SNPID, how="left"
        )

        # Calculate R2 (squared correlation with lead SNP)
        if locus_copy.ld is not None and len(locus_copy.sumstats) > 0:
            # Find lead SNP (lowest p-value)
            lead_idx = locus_copy.sumstats[ColName.P].idxmin()
            # Calculate R2 for all SNPs
            r2_values = locus_copy.ld.r[lead_idx] ** 2
            # Map R2 values to SNPIDs
            snpid_to_r2 = dict(zip(locus_copy.sumstats[ColName.SNPID], r2_values))
            result_df["R2"] = result_df[ColName.SNPID].map(snpid_to_r2)
        else:
            result_df["R2"] = np.nan

    else:
        # Multiple loci case - prefixed column names
        # First, add common columns that don't need prefix
        first_locus = locus_set.loci[0]
        common_cols = [
            ColName.CHR,
            ColName.BP,
            ColName.RSID,
            ColName.EA,
            ColName.NEA,
        ]
        available_common = [
            col for col in common_cols if col in first_locus.sumstats.columns
        ]

        # Use the first locus for common columns
        if available_common:
            result_df = result_df.merge(
                first_locus.sumstats[[ColName.SNPID] + available_common],
                on=ColName.SNPID,
                how="left",
            )

        # Add locus-specific columns with prefixes
        for locus in locus_set.loci:
            prefix = f"{locus.popu}_{locus.cohort}_"

            # Make sure we have matched sumstats and LD
            locus_copy = locus.copy()
            locus_copy = intersect_sumstat_ld(locus_copy)

            # Columns to add with prefix
            locus_cols = [
                ColName.EAF,
                ColName.MAF,
                ColName.BETA,
                ColName.SE,
                ColName.P,
            ]

            for col in locus_cols:
                if col in locus_copy.sumstats.columns:
                    col_data = locus_copy.sumstats[[ColName.SNPID, col]].copy()
                    col_data.rename(columns={col: f"{prefix}{col}"}, inplace=True)
                    result_df = result_df.merge(
                        col_data, on=ColName.SNPID, how="left"
                    )

            # Calculate R2
            if locus_copy.ld is not None and len(locus_copy.sumstats) > 0:
                lead_idx = locus_copy.sumstats[ColName.P].idxmin()
                r2_values = locus_copy.ld.r[lead_idx] ** 2
                snpid_to_r2 = dict(
                    zip(locus_copy.sumstats[ColName.SNPID], r2_values)
                )
                result_df[f"{prefix}R2"] = result_df[ColName.SNPID].map(snpid_to_r2)
            else:
                result_df[f"{prefix}R2"] = np.nan

            # Add per-locus PIP and CRED columns when available
            if self.per_locus_results:
                locus_creds = self.per_locus_results.get(locus.locus_id)
                if locus_creds is not None:
                    pip_col = f"{prefix}PIP"
                    result_df[pip_col] = (
                        result_df[ColName.SNPID]
                        .map(locus_creds.pips.to_dict())
                        .fillna(0.0)
                    )
                    cred_col = f"{prefix}CRED"
                    result_df[cred_col] = 0
                    for cs_idx, snp_list in enumerate(locus_creds.snps, 1):
                        mask = result_df[ColName.SNPID].isin(snp_list)
                        result_df.loc[mask, cred_col] = cs_idx

    # Add credible set assignments (CRED column)
    result_df["CRED"] = 0  # Default: not in any credible set
    for cs_idx, snp_list in enumerate(self.snps, 1):
        mask = result_df[ColName.SNPID].isin(snp_list)
        result_df.loc[mask, "CRED"] = cs_idx

    # Add PIP column
    result_df["PIP"] = result_df[ColName.SNPID].map(self.pips.to_dict()).fillna(0)

    # Sort by PIP descending
    result_df = result_df.sort_values("PIP", ascending=False)

    return result_df
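
The credible-set labelling and PIP sorting at the end of the listing above can be illustrated without pandas. This is a minimal sketch with made-up SNP IDs and PIPs; `assign_cred` is a hypothetical helper, not part of credtools. It mirrors the loop over `self.snps`: sets are numbered from 1, 0 means "not in any credible set", and a SNP listed in more than one set keeps the index of the last set that contains it.

```python
from typing import Dict, List


def assign_cred(all_snps: List[str], cred_sets: List[List[str]]) -> Dict[str, int]:
    """Label each SNP with its 1-based credible set index (0 = none)."""
    cred = {snp: 0 for snp in all_snps}
    for cs_idx, snp_list in enumerate(cred_sets, 1):
        for snp in snp_list:
            if snp in cred:
                cred[snp] = cs_idx  # later sets overwrite earlier ones
    return cred


# Hypothetical inputs for illustration only
pips = {"rs1": 0.05, "rs2": 0.90, "rs3": 0.40, "rs4": 0.01}
cred = assign_cred(list(pips), [["rs2"], ["rs3", "rs1"]])

# Order rows by PIP, highest first, as the summary table does
rows = sorted(pips, key=pips.get, reverse=True)
summary = [(snp, cred[snp], pips[snp]) for snp in rows]
print(summary)
```

The first row of `summary` is the variant with the highest PIP, together with the credible set it belongs to.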

from_dict(data, pips) classmethod

Create CredibleSet from dictionary and pips.

Parameters:

Name Type Description Default
data Dict[str, Any]

A dictionary containing the data to initialize the CredibleSet.

required
pips Series

Posterior inclusion probabilities.

required

Returns:

Type Description
CredibleSet

An instance of CredibleSet initialized with the provided data and pips.

Source code in credtools/credibleset.py
@classmethod
def from_dict(cls, data: Dict[str, Any], pips: pd.Series) -> "CredibleSet":
    """
    Create CredibleSet from dictionary and pips.

    Parameters
    ----------
    data : Dict[str, Any]
        A dictionary containing the data to initialize the CredibleSet.
    pips : pd.Series
        Posterior inclusion probabilities.

    Returns
    -------
    CredibleSet
        An instance of CredibleSet initialized with the provided data and pips.
    """
    return cls(
        tool=data["tool"],
        parameters=data["parameters"],
        coverage=data["coverage"],
        n_cs=data["n_cs"],
        cs_sizes=data["cs_sizes"],
        lead_snps=data["lead_snps"],
        snps=data["snps"],
        pips=pips,
        purity=data.get("purity"),
        converged=data.get("converged"),
        n_iter=data.get("n_iter"),
    )

set_per_locus_results(per_locus_results)

Attach per-locus credible set results.

Source code in credtools/credibleset.py
def set_per_locus_results(
    self, per_locus_results: Dict[str, "CredibleSet"]
) -> None:
    """Attach per-locus credible set results."""
    self._per_locus_results = per_locus_results

to_dict()

Convert to dictionary for TOML storage (excluding pips).

Returns:

Type Description
Dict[str, Any]

A dictionary representation of the CredibleSet excluding pips.
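
Because `to_dict` leaves the PIPs out and `from_dict` takes them as a separate argument, a round trip has to carry the PIPs alongside the dictionary. The sketch below shows that pattern with `MiniCredSet`, a hypothetical stand-in that keeps only a few fields and stores PIPs in a plain dict rather than a pandas Series; it is not the credtools class.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class MiniCredSet:
    """Hypothetical stand-in for CredibleSet with only a few fields."""

    tool: str
    n_cs: int
    snps: List[List[str]]
    pips: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        # PIPs are left out: they are stored separately from the metadata
        return {"tool": self.tool, "n_cs": self.n_cs, "snps": self.snps}

    @classmethod
    def from_dict(cls, data: Dict[str, Any], pips: Dict[str, float]) -> "MiniCredSet":
        return cls(tool=data["tool"], n_cs=data["n_cs"], snps=data["snps"], pips=pips)


original = MiniCredSet("susie", 1, [["rs1", "rs2"]], {"rs1": 0.7, "rs2": 0.3})
meta = original.to_dict()                              # metadata only
restored = MiniCredSet.from_dict(meta, original.pips)  # PIPs supplied separately
print(restored == original)
```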

Source code in credtools/credibleset.py
def to_dict(self) -> Dict[str, Any]:
    """
    Convert to dictionary for TOML storage (excluding pips).

    Returns
    -------
    Dict[str, Any]
        A dictionary representation of the CredibleSet excluding pips.
    """
    return {
        "tool": self.tool,
        "n_cs": self.n_cs,
        "coverage": self.coverage,
        "lead_snps": self.lead_snps,
        "snps": self.snps,
        "cs_sizes": self.cs_sizes,
        "parameters": self.parameters,
        "purity": self.purity,
        "converged": self.converged,
        "n_iter": self.n_iter,
    }

calculate_cs_purity(ld, cs_snp_ids)

Calculate purity for a single credible set.

Purity is defined as the minimum absolute LD R value between all pairs of SNPs in the credible set.

For multiple LD matrices (multi-ancestry case), purity is calculated as:

  1. Extract CS submatrix from each LD matrix
  2. Take element-wise maximum of absolute values across all matrices
  3. Return the minimum value from the resulting meta-LD matrix

This approach (similar to MultiSuSiE) ensures the credible set has high purity across all populations.
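
The meta-purity rule can be written out in a few lines. The sketch below is a simplified illustration, assuming every SNP of the credible set is present in every LD matrix (the library function additionally pads missing SNPs with zeros) and using nested lists instead of `LDMatrix` objects; `meta_purity` and the r values are made up for the example.

```python
from itertools import combinations
from typing import List

Matrix = List[List[float]]


def meta_purity(ld_matrices: List[Matrix]) -> float:
    """Min over SNP pairs of the max |r| across populations."""
    n = len(ld_matrices[0])
    pair_max = [
        max(abs(ld[i][j]) for ld in ld_matrices)
        for i, j in combinations(range(n), 2)
    ]
    return min(pair_max)


# Made-up LD r values for one 3-SNP credible set in two populations
eur = [[1.0, 0.8, 0.9], [0.8, 1.0, -0.7], [0.9, -0.7, 1.0]]
afr = [[1.0, 0.6, 0.85], [0.6, 1.0, 0.75], [0.85, 0.75, 1.0]]

print(meta_purity([eur]))       # single-population purity
print(meta_purity([eur, afr]))  # meta-purity across both populations
```

Taking the maximum across populations before the minimum means a SNP pair only lowers the purity if it is weakly correlated in every population.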

Parameters:

Name Type Description Default
ld LDMatrix or List[LDMatrix]

LDMatrix object(s) containing both r matrix and map with SNPIDs. If a list is provided, meta-purity across all matrices is calculated.

required
cs_snp_ids List[str]

List of SNPID strings in the credible set.

required

Returns:

Type Description
Optional[float]
  • If CS has only 1 SNP, returns 1.0
  • If CS has multiple SNPs, returns min(|R|) for all SNP pairs
  • For multiple LD matrices, returns min of element-wise max across matrices
  • If unable to calculate (e.g., SNPs not in LD matrix), returns None

Examples:

>>> # Single LD matrix: CS with 3 SNPs having LD R values: 0.8, 0.9, 0.7
>>> # Purity = min(|0.8|, |0.9|, |0.7|) = 0.7
>>>
>>> # Multiple LD matrices: same CS in EUR and AFR
>>> # EUR: |R| values = [0.8, 0.9, 0.7]
>>> # AFR: |R| values = [0.6, 0.85, 0.75]
>>> # Meta |R| = max([0.8, 0.9, 0.7], [0.6, 0.85, 0.75]) = [0.8, 0.9, 0.75]
>>> # Purity = min([0.8, 0.9, 0.75]) = 0.75
Source code in credtools/credibleset.py
def calculate_cs_purity(
    ld: Union["LDMatrix", List["LDMatrix"]],
    cs_snp_ids: List[str],
) -> Optional[float]:
    """
    Calculate purity for a single credible set.

    Purity is defined as the minimum absolute LD R value between all pairs of
    SNPs in the credible set.

    For multiple LD matrices (multi-ancestry case), purity is calculated as:
    1. Extract CS submatrix from each LD matrix
    2. Take element-wise maximum of absolute values across all matrices
    3. Return the minimum value from the resulting meta-LD matrix

    This approach (similar to MultiSuSiE) ensures the credible set has high
    purity across all populations.

    Parameters
    ----------
    ld : LDMatrix or List[LDMatrix]
        LDMatrix object(s) containing both r matrix and map with SNPIDs.
        If a list is provided, meta-purity across all matrices is calculated.
    cs_snp_ids : List[str]
        List of SNPID strings in the credible set.

    Returns
    -------
    Optional[float]
        - If CS has only 1 SNP, returns 1.0
        - If CS has multiple SNPs, returns min(|R|) for all SNP pairs
        - For multiple LD matrices, returns min of element-wise max across matrices
        - If unable to calculate (e.g., SNPs not in LD matrix), returns None

    Examples
    --------
    >>> # Single LD matrix: CS with 3 SNPs having LD R values: 0.8, 0.9, 0.7
    >>> # Purity = min(|0.8|, |0.9|, |0.7|) = 0.7
    >>>
    >>> # Multiple LD matrices: same CS in EUR and AFR
    >>> # EUR: |R| values = [0.8, 0.9, 0.7]
    >>> # AFR: |R| values = [0.6, 0.85, 0.75]
    >>> # Meta |R| = max([0.8, 0.9, 0.7], [0.6, 0.85, 0.75]) = [0.8, 0.9, 0.75]
    >>> # Purity = min([0.8, 0.9, 0.75]) = 0.75
    """
    from credtools.constants import ColName

    if len(cs_snp_ids) == 1:
        return 1.0

    # Handle single LD matrix vs list of LD matrices
    if not isinstance(ld, list):
        ld_list = [ld]
    else:
        ld_list = ld

    if len(ld_list) == 0:
        return None

    # Keep the CS SNPs that appear in at least one LD matrix (MultiSuSiE approach),
    # preserving input order and dropping duplicates.
    # This ensures all submatrices have the same dimensions.
    ld_snp_sets = [set(ld_matrix.map[ColName.SNPID]) for ld_matrix in ld_list]
    union_snps = []
    seen = set()
    for snpid in cs_snp_ids:
        if snpid not in seen and any(snpid in snps for snps in ld_snp_sets):
            union_snps.append(snpid)
            seen.add(snpid)

    if len(union_snps) < 2:
        # Not enough SNPs found in any LD matrix
        return None

    # Create mapping from SNP ID to index in union set
    variant_to_index = {snpid: i for i, snpid in enumerate(union_snps)}
    n_union = len(union_snps)

    # Extract and expand CS submatrices from all LD matrices
    cs_submatrices = []
    for ld_matrix in ld_list:
        # Create SNPID to index mapping for this LD matrix
        snpid_to_idx = {
            snpid: i for i, snpid in enumerate(ld_matrix.map[ColName.SNPID])
        }

        # Initialize expanded LD matrix with zeros
        expand_ld = np.zeros((n_union, n_union), dtype=np.float32)

        # Find SNPs that exist in both union and this LD matrix
        present_snps = [snpid for snpid in union_snps if snpid in snpid_to_idx]

        if len(present_snps) >= 2:
            # Get indices in this LD matrix
            ld_indices = np.array([snpid_to_idx[snpid] for snpid in present_snps])
            # Get indices in union set
            union_indices = np.array(
                [variant_to_index[snpid] for snpid in present_snps]
            )

            # Extract submatrix for present SNPs from this LD matrix
            ld_submatrix = ld_matrix.r[np.ix_(ld_indices, ld_indices)]

            # Place LD values at correct positions in expanded matrix using meshgrid
            idx_i, idx_j = np.meshgrid(union_indices, union_indices)
            expand_ld[idx_i, idx_j] = ld_submatrix.astype(np.float32)

        # Set diagonal to 1 (for both present and missing SNPs)
        np.fill_diagonal(expand_ld, 1.0)

        cs_submatrices.append(expand_ld)

    if len(cs_submatrices) == 0:
        # No valid submatrices found
        return None

    # Calculate meta-purity across all LD matrices
    # Take element-wise maximum of absolute values (MultiSuSiE approach)
    abs_meta_R = np.abs(cs_submatrices[0])
    for submatrix in cs_submatrices[1:]:
        abs_meta_R = np.maximum(abs_meta_R, np.abs(submatrix))

    # Get upper triangle (excluding diagonal) and find minimum
    upper_tri_indices = np.triu_indices_from(abs_meta_R, k=1)
    r_values = abs_meta_R[upper_tri_indices]

    if len(r_values) == 0:
        return None

    return float(np.min(r_values))

cluster_cs(dict_sets, threshold=0.9)

Cluster dictionaries from different sets based on continuous Jaccard similarity.

Parameters:

Name Type Description Default
dict_sets List[List[Dict[str, float]]]

List of m sets, where each set contains dictionaries with PIP values.

required
threshold float

Distance threshold (1 - continuous Jaccard similarity) at which the dendrogram is cut, by default 0.9.

0.9

Returns:

Type Description
List[List[str]]

List of merged clusters, where each cluster contains a list of unique SNP IDs from the dictionaries in that cluster.

Raises:

Type Description
ValueError

If fewer than two sets of dictionaries are provided or if any set is empty.
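
As the source listing shows, `threshold` is compared against distances (1 - similarity), and `combine_creds` passes `1 - jaccard_threshold`. The sketch below illustrates that conversion for the simplest case of two credible sets, where hierarchical clustering reduces to a single comparison. It is an illustration under that assumption only: `would_merge` is a hypothetical helper, and with three or more sets average linkage works on averaged distances instead.

```python
def would_merge(similarity: float, threshold: float = 0.9) -> bool:
    """Two-set case: merge when distance (1 - similarity) <= threshold."""
    return (1.0 - similarity) <= threshold


# combine_creds calls cluster_cs with threshold = 1 - jaccard_threshold
jaccard_threshold = 0.1
cut = 1.0 - jaccard_threshold

print(would_merge(0.29, cut))  # distance 0.71 is within the cut
print(would_merge(0.0, cut))   # disjoint sets, distance 1.0 exceeds the cut
```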

Examples:

>>> sets = [
...     [{'a': 0.8, 'b': 0.5}],
...     [{'b': 0.6, 'c': 0.3}]
... ]
>>> clusters = cluster_cs(sets)
Source code in credtools/credibleset.py
def cluster_cs(
    dict_sets: List[List[Dict[str, float]]], threshold: float = 0.9
) -> List[List[str]]:
    """
    Cluster dictionaries from different sets based on continuous Jaccard similarity.

    Parameters
    ----------
    dict_sets : List[List[Dict[str, float]]]
        List of m sets, where each set contains dictionaries with PIP values.
    threshold : float, optional
        Distance threshold (1 - continuous Jaccard similarity) at which the
        dendrogram is cut, by default 0.9.

    Returns
    -------
    List[List[str]]
        List of merged clusters, where each cluster contains
        a list of unique SNP IDs from the dictionaries in that cluster.

    Raises
    ------
    ValueError
        If fewer than two sets of dictionaries are provided or if any set is empty.

    Examples
    --------
    >>> sets = [
    ...     [{'a': 0.8, 'b': 0.5}],
    ...     [{'b': 0.6, 'c': 0.3}]
    ... ]
    >>> clusters = cluster_cs(sets)
    """
    if len(dict_sets) < 2:
        raise ValueError("At least two sets of dictionaries are required")

    # Validate input
    for dict_set in dict_sets:
        if not dict_set:
            raise ValueError("Empty dictionary sets are not allowed")

    # Create similarity matrix
    similarity_matrix, all_dicts = create_similarity_matrix(dict_sets)

    # Convert similarity to distance (1 - similarity)
    distance_matrix = 1 - similarity_matrix

    # Perform hierarchical clustering
    condensed_dist = distance_matrix[np.triu_indices(len(distance_matrix), k=1)]

    if len(condensed_dist) == 0:
        logger.warning("No valid distances found for clustering")
        return [list(set(all_dicts[0].keys()))]

    linkage_matrix = linkage(condensed_dist, method="average")

    # Cut the dendrogram at the specified threshold
    clusters = fcluster(linkage_matrix, threshold, criterion="distance")

    # Group dictionaries by cluster and merge them
    cluster_groups: Dict[int, List[str]] = {}
    for idx, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []

        # Merge every dictionary in the cluster by pooling its keys (SNP IDs);
        # PIP values are dropped and duplicates are removed on return
        current_dict = all_dicts[idx]
        cluster_groups[cluster_id].extend(current_dict.keys())

    return [
        list(set(cluster_groups[cluster_id])) for cluster_id in sorted(cluster_groups)
    ]

combine_creds(creds, combine_cred='union', combine_pip='max', jaccard_threshold=0.1, ld_matrices=None, min_purity=0.0)

Combine credible sets from multiple tools.

Parameters:

Name Type Description Default
creds List[CredibleSet]

List of credible sets from multiple tools.

required
combine_cred str

Method to combine credible sets, by default "union". Options: "union", "intersection", "cluster".

  • "union": Union of all credible sets to form a merged credible set.
  • "intersection": First merge the credible sets from the same tool, then take the intersection of all merged credible sets. No credible set will be returned if no common SNPs found.
  • "cluster": Merge credible sets with Jaccard index > jaccard_threshold.
'union'
combine_pip str

Method to combine PIPs, by default "max". Options: "max", "min", "mean", "meta".

  • "meta": PIP_meta = 1 - prod(1 - PIP_i), where i is the index of tools, PIP_i = 0 when the SNP is not in the credible set of the tool.
  • "max": Maximum PIP value for each SNP across all tools.
  • "min": Minimum PIP value for each SNP across all tools.
  • "mean": Mean PIP value for each SNP across all tools.
'max'
jaccard_threshold float

Jaccard index threshold for the "cluster" method, by default 0.1.

0.1
ld_matrices Optional[List[LDMatrix]]

List of LD matrices for purity calculation, by default None. If provided, purity will be calculated for merged credible sets using multi-ancestry approach (element-wise max across populations). If None, purity will not be calculated for the merged credible sets.

None
min_purity float

Minimum purity threshold for filtering credible sets, by default 0.0. After combining credible sets, only those with purity >= min_purity will be kept. Purity is the minimum absolute LD R value between all SNP pairs in a credible set. Set to 0.0 (default) for no filtering.

0.0

Returns:

Type Description
CredibleSet

Combined credible set.

Raises:

Type Description
ValueError

If the method is not supported.

Notes

'union' and 'intersection' methods will merge all credible sets into one.
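
The difference between the two set-based options can be seen on plain lists of SNP IDs. This is a minimal sketch with hypothetical inputs and a hypothetical `flatten` helper; it does not call credtools and only reproduces the pooling logic: all credible sets of a tool are pooled first, then the pooled sets are united or intersected into a single merged credible set.

```python
from typing import List, Set

# Hypothetical credible sets from two tools (each tool may report several sets)
tool_a: List[List[str]] = [["rs1", "rs2"], ["rs5"]]
tool_b: List[List[str]] = [["rs2", "rs3"], ["rs5", "rs6"]]


def flatten(cs_lists: List[List[str]]) -> Set[str]:
    """Pool all credible sets of one tool into a single set of SNP IDs."""
    return {snp for cs in cs_lists for snp in cs}


union = flatten(tool_a) | flatten(tool_b)
intersection = flatten(tool_a) & flatten(tool_b)

print(sorted(union))
print(sorted(intersection))
```

Either way the distinction between a tool's individual credible sets is lost; only the "cluster" option can return more than one merged set.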

Source code in credtools/credibleset.py
def combine_creds(
    creds: List[CredibleSet],
    combine_cred: str = "union",
    combine_pip: str = "max",
    jaccard_threshold: float = 0.1,
    ld_matrices: Optional[List["LDMatrix"]] = None,
    min_purity: float = 0.0,
) -> CredibleSet:
    """
    Combine credible sets from multiple tools.

    Parameters
    ----------
    creds : List[CredibleSet]
        List of credible sets from multiple tools.
    combine_cred : str, optional
        Method to combine credible sets, by default "union".
        Options: "union", "intersection", "cluster".

        - "union": Union of all credible sets to form a merged credible set.
        - "intersection": First merge the credible sets from the same tool,
            then take the intersection of all merged credible sets.
            No credible set will be returned if no common SNPs found.
        - "cluster": Merge credible sets with Jaccard index > jaccard_threshold.
    combine_pip : str, optional
        Method to combine PIPs, by default "max".
        Options: "max", "min", "mean", "meta".

        - "meta": PIP_meta = 1 - prod(1 - PIP_i), where i is the index of tools,
            PIP_i = 0 when the SNP is not in the credible set of the tool.
        - "max": Maximum PIP value for each SNP across all tools.
        - "min": Minimum PIP value for each SNP across all tools.
        - "mean": Mean PIP value for each SNP across all tools.
    jaccard_threshold : float, optional
        Jaccard index threshold for the "cluster" method, by default 0.1.
    ld_matrices : Optional[List[LDMatrix]], optional
        List of LD matrices for purity calculation, by default None.
        If provided, purity will be calculated for merged credible sets using
        multi-ancestry approach (element-wise max across populations).
        If None, purity will not be calculated for the merged credible sets.
    min_purity : float, optional
        Minimum purity threshold for filtering credible sets, by default 0.0.
        After combining credible sets, only those with purity >= min_purity will be kept.
        Purity is the minimum absolute LD R value between all SNP pairs in a credible set.
        Set to 0.0 (default) for no filtering.

    Returns
    -------
    CredibleSet
        Combined credible set.

    Raises
    ------
    ValueError
        If the method is not supported.

    Notes
    -----
    'union' and 'intersection' methods will merge all credible sets into one.
    """
    paras = creds[0].parameters
    tool = creds[0].tool
    # filter out the creds with no credible set
    creds = [cred for cred in creds if cred.n_cs > 0]
    if len(creds) == 0:
        logger.warning("No credible sets found in the input list.")
        return CredibleSet(
            tool=tool,
            n_cs=0,
            coverage=0,
            lead_snps=[],
            snps=[],
            cs_sizes=[],
            pips=pd.Series(dtype=float),  # explicit dtype for the empty Series
            parameters=paras,
        )
    if len(creds) == 1:
        return creds[0]
    if combine_cred == "union":
        merged_snps_flat = []
        for cred in creds:
            snps = [i for snp in cred.snps for i in snp]
            merged_snps_flat.extend(snps)
        merged_snps = [list(set(merged_snps_flat))]
    elif combine_cred == "intersection":
        merged_snps_set = None
        for i, cred in enumerate(creds):
            snps = [item for snp in cred.snps for item in snp]
            if i == 0:
                merged_snps_set = set(snps)
            else:
                if merged_snps_set is not None:
                    merged_snps_set.intersection_update(set(snps))
        if merged_snps_set is None or len(merged_snps_set) == 0:
            logger.warning("No common SNPs found in the intersection of credible sets.")
            # Return no credible set; an empty set has no lead SNP to select
            merged_snps = []
        else:
            merged_snps = [list(merged_snps_set)]
    elif combine_cred == "cluster":
        cred_pips = []
        for cred in creds:
            cred_pip = [dict(cred.pips[cred.pips.index.isin(snp)]) for snp in cred.snps]
            cred_pips.append(cred_pip)
        merged_snps = cluster_cs(cred_pips, 1 - jaccard_threshold)
        paras["jaccard_threshold"] = jaccard_threshold
    else:
        raise ValueError(f"Method {combine_cred} is not supported.")
    merged_pips = combine_pips([cred.pips for cred in creds], combine_pip)
    paras["combine_cred"] = combine_cred
    paras["combine_pip"] = combine_pip

    # Calculate purity for merged credible sets if LD matrices are provided
    purity = None
    if ld_matrices is not None and len(ld_matrices) > 0:
        purity = []
        for snp_list in merged_snps:
            if len(snp_list) > 0:
                purity_val = calculate_cs_purity(ld_matrices, snp_list)
                purity.append(purity_val)
            else:
                purity.append(None)
        logger.info(
            f"Calculated purity for {len(purity)} merged credible sets: {purity}"
        )

    merged = CredibleSet(
        tool=creds[0].tool,
        n_cs=len(merged_snps),
        coverage=creds[0].coverage,
        lead_snps=[
            str(merged_pips[merged_pips.index.isin(snp)].idxmax())
            for snp in merged_snps
        ],
        snps=merged_snps,
        cs_sizes=[len(i) for i in merged_snps],
        pips=merged_pips,
        parameters=paras,
        purity=purity,
    )

    # Apply purity filtering if requested
    if min_purity > 0:
        merged = filter_credset_by_purity(merged, min_purity=min_purity)

    return merged

combine_pips(pips, method='max')

Combine PIPs from multiple tools.

Parameters:

Name Type Description Default
pips List[Series]

List of PIPs from multiple tools.

required
method str

Method to combine PIPs, by default "max". Options: "max", "min", "mean", "meta". With "meta", the combined value is PIP_meta = 1 - prod(1 - PIP_i), where i is the index of tools. For every method, a SNP missing from a tool's PIP series is assigned PIP_i = 0 for that tool, because missing values are filled with 0 before combining. As a result, "min" returns 0 and "mean" is pulled toward 0 for SNPs that some tool did not report.

'max'

Returns:

Type Description
Series

Combined PIPs.

Raises:

Type Description
ValueError

If the method is not supported.
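
A small worked case makes the "meta" formula concrete. The sketch below reimplements only the "meta" and "max" options on plain dicts; `combine` and the PIP values are hypothetical, and treating an absent SNP as PIP 0 mirrors the `fillna(0)` step in the source listing.

```python
from math import prod
from typing import Dict, List


def combine(pips: List[Dict[str, float]], method: str = "max") -> Dict[str, float]:
    """Combine per-tool PIPs for every SNP seen by at least one tool."""
    snps = sorted(set().union(*pips))
    out = {}
    for snp in snps:
        values = [p.get(snp, 0.0) for p in pips]  # absent SNP -> PIP 0
        if method == "meta":
            out[snp] = 1 - prod(1 - v for v in values)
        elif method == "max":
            out[snp] = max(values)
        else:
            raise ValueError(f"Method {method} is not supported.")
    return out


# Hypothetical PIPs from two tools
tool_1 = {"rs1": 0.8, "rs2": 0.1}
tool_2 = {"rs1": 0.5, "rs3": 0.2}

meta = combine([tool_1, tool_2], "meta")
best = combine([tool_1, tool_2], "max")
print(meta)
print(best)
```

For rs1 the meta value is 1 - (1 - 0.8)(1 - 0.5) = 0.9, higher than either input, whereas "max" keeps 0.8.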

Source code in credtools/credibleset.py
def combine_pips(pips: List[pd.Series], method: str = "max") -> pd.Series:
    """
    Combine PIPs from multiple tools.

    Parameters
    ----------
    pips : List[pd.Series]
        List of PIPs from multiple tools.
    method : str, optional
        Method to combine PIPs, by default "max".
        Options: "max", "min", "mean", "meta".
        When "meta" is selected, the method uses the formula
        PIP_meta = 1 - prod(1 - PIP_i), where i is the index of tools.
        For every method, a SNP missing from a tool's PIP series gets
        PIP_i = 0 for that tool (missing values are filled with 0 first),
        so "min" and "mean" are pulled toward 0 for such SNPs.

    Returns
    -------
    pd.Series
        Combined PIPs.

    Raises
    ------
    ValueError
        If the method is not supported.
    """
    logger.info(f"Combining PIPs using method: {method}")
    pip_df = pd.DataFrame(pips).T
    pip_df = pip_df.fillna(0)
    if method == "meta":
        merged = 1 - np.prod(1 - pip_df, axis=1)
    elif method == "max":
        merged = pip_df.max(axis=1)
    elif method == "min":
        merged = pip_df.min(axis=1)
    elif method == "mean":
        merged = pip_df.mean(axis=1)
    else:
        raise ValueError(f"Method {method} is not supported.")
    return merged

continuous_jaccard(dict1, dict2)

Calculate modified Jaccard similarity for continuous values (PIP values).

Formula: ∑min(xi,yi)/∑max(xi,yi) where xi, yi are PIP values or 0 if missing

Citation: Yuan, K. et al. (2024) Nature Genetics https://doi.org/10.1038/s41588-024-01870-z.
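
The formula can be checked by hand on two small PIP dictionaries. The sketch below is a compact re-statement of the formula without the input validation; `continuous_jaccard_sketch` is a name chosen for this illustration.

```python
from typing import Dict


def continuous_jaccard_sketch(d1: Dict[str, float], d2: Dict[str, float]) -> float:
    """Sum of per-SNP minima divided by sum of per-SNP maxima."""
    keys = sorted(set(d1) | set(d2))
    sum_min = sum(min(d1.get(k, 0.0), d2.get(k, 0.0)) for k in keys)
    sum_max = sum(max(d1.get(k, 0.0), d2.get(k, 0.0)) for k in keys)
    return sum_min / sum_max if sum_max > 0 else 0.0


d1 = {"a": 0.8, "b": 0.5}
d2 = {"b": 0.6, "c": 0.3}
# SNP a: min 0.0, max 0.8; SNP b: min 0.5, max 0.6; SNP c: min 0.0, max 0.3
# so the score is 0.5 / 1.7
score = continuous_jaccard_sketch(d1, d2)
print(round(score, 3))  # 0.294
```

SNPs reported by only one side add to the denominator but not the numerator, so the score drops as the two sets diverge.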

Parameters:

Name Type Description Default
dict1 Dict[str, float]

First dictionary with keys and PIP values (0-1).

required
dict2 Dict[str, float]

Second dictionary with keys and PIP values (0-1).

required

Returns:

Type Description
float

Modified Jaccard similarity index between 0 and 1.

Raises:

Type Description
ValueError

If any values are not between 0 and 1.

Examples:

>>> d1 = {'a': 0.8, 'b': 0.5}
>>> d2 = {'b': 0.6, 'c': 0.3}
>>> round(continuous_jaccard(d1, d2), 3)
0.294
Source code in credtools/credibleset.py
def continuous_jaccard(dict1: Dict[str, float], dict2: Dict[str, float]) -> float:
    """
    Calculate modified Jaccard similarity for continuous values (PIP values).

    Formula: ∑min(xi,yi)/∑max(xi,yi) where xi, yi are PIP values or 0 if missing

    Citation: Yuan, K. et al. (2024) Nature Genetics https://doi.org/10.1038/s41588-024-01870-z.

    Parameters
    ----------
    dict1 : Dict[str, float]
        First dictionary with keys and PIP values (0-1).
    dict2 : Dict[str, float]
        Second dictionary with keys and PIP values (0-1).

    Returns
    -------
    float
        Modified Jaccard similarity index between 0 and 1.

    Raises
    ------
    ValueError
        If any values are not between 0 and 1.

    Examples
    --------
    >>> d1 = {'a': 0.8, 'b': 0.5}
    >>> d2 = {'b': 0.6, 'c': 0.3}
    >>> round(continuous_jaccard(d1, d2), 3)
    0.294
    """
    # Validate input values
    for d in [dict1, dict2]:
        invalid_values = [v for v in d.values() if not (0 <= v <= 1)]
        if invalid_values:
            raise ValueError("All values must be between 0 and 1")

    # Get all keys
    all_keys = set(dict1.keys()).union(set(dict2.keys()))

    # Calculate sum of minimums and maximums
    sum_min = 0.0
    sum_max = 0.0

    for key in all_keys:
        val1 = dict1.get(key, 0.0)
        val2 = dict2.get(key, 0.0)
        sum_min += min(val1, val2)
        sum_max += max(val1, val2)

    return sum_min / sum_max if sum_max > 0 else 0.0

create_similarity_matrix(dict_sets)

Create a similarity matrix for all pairs of dictionaries across different sets. Pairs of dictionaries from the same set are not compared and keep a similarity of 0.

Parameters:

Name Type Description Default
dict_sets List[List[Dict[str, float]]]

List of m sets, where each set contains dictionaries with PIP values.

required

Returns:

Type Description
Tuple[ndarray, List[Dict[str, float]]]

A tuple containing:

  • Similarity matrix (n_dicts x n_dicts)
  • Flattened list of dictionaries
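
The matrix layout can be sketched in pure Python. In the source listing only dictionaries from different sets are compared, and same-set pairs stay at 0; the sketch below reproduces that layout. It makes two simplifying assumptions that are not in the library: the similarity is a plain Jaccard index on SNP IDs (`key_overlap`, a stand-in that ignores PIP values), and the result is a nested list instead of a NumPy array.

```python
from itertools import combinations
from typing import Callable, Dict, List

PipDict = Dict[str, float]


def key_overlap(d1: PipDict, d2: PipDict) -> float:
    """Stand-in similarity: plain Jaccard index on SNP IDs (ignores PIPs)."""
    union = set(d1) | set(d2)
    return len(set(d1) & set(d2)) / len(union) if union else 0.0


def similarity_matrix(
    dict_sets: List[List[PipDict]],
    sim: Callable[[PipDict, PipDict], float] = key_overlap,
) -> List[List[float]]:
    # Record which set (tool) each flattened dictionary came from
    owners = [s for s, dict_set in enumerate(dict_sets) for _ in dict_set]
    flat = [d for dict_set in dict_sets for d in dict_set]
    matrix = [[0.0] * len(flat) for _ in flat]
    for i, j in combinations(range(len(flat)), 2):
        if owners[i] != owners[j]:  # same-set pairs stay at 0
            matrix[i][j] = matrix[j][i] = sim(flat[i], flat[j])
    return matrix


sets = [
    [{"a": 0.8, "b": 0.5}, {"b": 0.4}],  # two credible sets from tool 1
    [{"b": 0.6, "c": 0.3}],              # one credible set from tool 2
]
matrix = similarity_matrix(sets)
print(matrix)
```

The two credible sets from tool 1 share SNP b, yet their entry is 0, which keeps credible sets from the same tool far apart in the later clustering step.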

Examples:

>>> sets = [[{'a': 0.8, 'b': 0.5}], [{'b': 0.6, 'c': 0.3}]]
>>> matrix, dicts = create_similarity_matrix(sets)
Source code in credtools/credibleset.py
def create_similarity_matrix(
    dict_sets: List[List[Dict[str, float]]],
) -> Tuple[np.ndarray, List[Dict[str, float]]]:
    """
    Create a similarity matrix for all pairs of dictionaries across different sets.

    Pairs of dictionaries from the same set are not compared and keep a
    similarity of 0.

    Parameters
    ----------
    dict_sets : List[List[Dict[str, float]]]
        List of m sets, where each set contains dictionaries with PIP values.

    Returns
    -------
    Tuple[np.ndarray, List[Dict[str, float]]]
        A tuple containing:
        - Similarity matrix (n_dicts x n_dicts)
        - Flattened list of dictionaries

    Examples
    --------
    >>> sets = [[{'a': 0.8, 'b': 0.5}], [{'b': 0.6, 'c': 0.3}]]
    >>> matrix, dicts = create_similarity_matrix(sets)
    """
    # Flatten all dictionaries while keeping track of their set membership
    all_dicts = []
    for dict_set in dict_sets:
        all_dicts.extend(dict_set)

    total_dicts = len(all_dicts)

    # Create similarity matrix
    similarity_matrix = np.zeros((total_dicts, total_dicts))

    # Calculate set membership ranges
    set_ranges = []
    current_idx = 0
    for dict_set in dict_sets:
        set_ranges.append((current_idx, current_idx + len(dict_set)))
        current_idx += len(dict_set)

    # Fill similarity matrix
    for i, j in combinations(range(total_dicts), 2):
        # Check if dictionaries are from the same set
        same_set = False
        for start, end in set_ranges:
            if start <= i < end and start <= j < end:
                same_set = True
                break

        if not same_set:
            similarity = continuous_jaccard(all_dicts[i], all_dicts[j])
            similarity_matrix[i, j] = similarity
            similarity_matrix[j, i] = similarity

    return similarity_matrix, all_dicts

filter_credset_by_purity(credset, min_purity=0.0)

Filter credible sets by purity threshold.

Removes credible sets that do not meet the minimum purity requirement. Purity is defined as the minimum absolute LD R value between all pairs of SNPs in the credible set.

Parameters:

Name Type Description Default
credset CredibleSet

CredibleSet object containing credible sets and their purity values.

required
min_purity float

Minimum purity threshold for filtering, by default 0.0. Credible sets with purity < min_purity will be removed. Set to 0.0 (default) for no filtering.

0.0

Returns:

Type Description
CredibleSet

New CredibleSet object with only credible sets meeting purity threshold. If no credible sets pass filtering, returns empty CredibleSet (n_cs=0).

Notes
  • If credset.purity is None or empty, no filtering is applied (returns original credset)
  • If min_purity <= 0, no filtering is applied (returns original credset)
  • Filtered credible sets maintain their original ordering
  • PIPs are preserved for all variants (not filtered)
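
The filtering rule can be sketched on parallel lists. This is a simplified illustration, not the library function: `filter_by_purity` is a hypothetical helper that works on plain lists instead of a `CredibleSet`, and the inputs are made up. As in the source listing, a credible set whose purity is `None` is dropped once filtering is active.

```python
from typing import List, Optional, Tuple


def filter_by_purity(
    snps: List[List[str]],
    purity: List[Optional[float]],
    min_purity: float,
) -> Tuple[List[List[str]], List[Optional[float]]]:
    """Keep credible sets whose purity is known and >= min_purity."""
    if min_purity <= 0:
        return snps, purity  # no filtering requested
    keep = [
        i for i, p in enumerate(purity) if p is not None and p >= min_purity
    ]
    return [snps[i] for i in keep], [purity[i] for i in keep]


# Hypothetical credible sets with their purity values
snps = [["rs1", "rs2"], ["rs3"], ["rs4", "rs5"], ["rs6", "rs7"]]
purity = [0.92, 1.0, 0.31, None]

kept_snps, kept_purity = filter_by_purity(snps, purity, min_purity=0.5)
print(kept_snps, kept_purity)
```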

Examples:

>>> # Filter credible sets to keep only high-purity sets (purity >= 0.5)
>>> filtered_cs = filter_credset_by_purity(credset, min_purity=0.5)
>>> print(f"Original: {credset.n_cs} CS, Filtered: {filtered_cs.n_cs} CS")
Original: 5 CS, Filtered: 3 CS
>>> # No filtering (default)
>>> same_cs = filter_credset_by_purity(credset, min_purity=0.0)
>>> assert same_cs.n_cs == credset.n_cs
Source code in credtools/credibleset.py
def filter_credset_by_purity(
    credset: "CredibleSet",
    min_purity: float = 0.0,
) -> "CredibleSet":
    """
    Filter credible sets by purity threshold.

    Removes credible sets that do not meet the minimum purity requirement.
    Purity is defined as the minimum absolute LD R value between all pairs
    of SNPs in the credible set.

    Parameters
    ----------
    credset : CredibleSet
        CredibleSet object containing credible sets and their purity values.
    min_purity : float, optional
        Minimum purity threshold for filtering, by default 0.0.
        Credible sets with purity < min_purity will be removed.
        Set to 0.0 (default) for no filtering.

    Returns
    -------
    CredibleSet
        New CredibleSet object with only credible sets meeting purity threshold.
        If no credible sets pass filtering, returns empty CredibleSet (n_cs=0).

    Notes
    -----
    - If credset.purity is None or empty, no filtering is applied (returns original credset)
    - If min_purity <= 0, no filtering is applied (returns original credset)
    - Filtered credible sets maintain their original ordering
    - PIPs are preserved for all variants (not filtered)

    Examples
    --------
    >>> # Filter credible sets to keep only high-purity sets (purity >= 0.5)
    >>> filtered_cs = filter_credset_by_purity(credset, min_purity=0.5)
    >>> print(f"Original: {credset.n_cs} CS, Filtered: {filtered_cs.n_cs} CS")
    Original: 5 CS, Filtered: 3 CS

    >>> # No filtering (default)
    >>> same_cs = filter_credset_by_purity(credset, min_purity=0.0)
    >>> assert same_cs.n_cs == credset.n_cs
    """
    # No filtering if min_purity <= 0
    if min_purity <= 0:
        return credset

    # No filtering if purity values are not available
    if credset.purity is None or len(credset.purity) == 0:
        logger.warning(
            "Purity values not available for filtering. "
            "Returning original credible set without filtering."
        )
        return credset

    # No credible sets to filter
    if credset.n_cs == 0:
        return credset

    # Filter credible sets by purity threshold
    keep_indices = []
    for i, purity_val in enumerate(credset.purity):
        if purity_val is not None and purity_val >= min_purity:
            keep_indices.append(i)

    # If no credible sets pass filtering, return empty CredibleSet
    if len(keep_indices) == 0:
        logger.warning(
            f"No credible sets passed purity filtering (min_purity={min_purity}). "
            f"All {credset.n_cs} credible sets were filtered out."
        )
        return CredibleSet(
            tool=credset.tool,
            n_cs=0,
            coverage=credset.coverage,
            lead_snps=[],
            snps=[],
            cs_sizes=[],
            pips=credset.pips,
            parameters=credset.parameters,
            purity=[],
        )

    # Filter credible sets
    filtered_snps = [credset.snps[i] for i in keep_indices]
    filtered_lead_snps = [credset.lead_snps[i] for i in keep_indices]
    filtered_cs_sizes = [credset.cs_sizes[i] for i in keep_indices]
    filtered_purity = [credset.purity[i] for i in keep_indices]

    logger.info(
        f"Filtered credible sets by purity >= {min_purity}: "
        f"{credset.n_cs} -> {len(keep_indices)} credible sets"
    )

    return CredibleSet(
        tool=credset.tool,
        n_cs=len(keep_indices),
        coverage=credset.coverage,
        lead_snps=filtered_lead_snps,
        snps=filtered_snps,
        cs_sizes=filtered_cs_sizes,
        pips=credset.pips,  # Keep all PIPs (not filtered)
        parameters=credset.parameters,
        purity=filtered_purity,
    )

generate_cs_summary(causal_variants, locus_id, locus_set)

Generate credible set summary from causal variants.

Creates one summary row per credible set, including lead SNP, size, PIP thresholds, and purity metrics.

Parameters:

Name Type Description Default
causal_variants DataFrame

DataFrame of causal variants (rows where CRED != 0). Must have columns: CRED, PIP, SNPID.

required
locus_id str

Locus identifier string.

required
locus_set LocusSet

LocusSet object for LD-based purity calculation.

required

Returns:

Type Description
List[Dict]

List of summary dictionaries, one per credible set.
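To make the shape of the returned summaries concrete, here is a rough pure-Python analogue of the per-credible-set bookkeeping. It assumes the input is a list of dicts with the same `CRED`, `PIP`, and `SNPID` keys rather than a DataFrame, omits the LD-based purity field, and uses a hypothetical function name and locus identifier.

```python
from collections import defaultdict
from typing import Dict, List


def summarize_credible_sets(rows: List[Dict], locus_id: str) -> List[Dict]:
    """Build one summary dict per credible set (hypothetical analogue, no purity)."""
    # Group variants by their credible set identifier.
    by_cs = defaultdict(list)
    for row in rows:
        by_cs[row["CRED"]].append(row)

    summary = []
    for cs_id in sorted(by_cs):
        members = by_cs[cs_id]
        lead = max(members, key=lambda r: r["PIP"])  # highest-PIP variant
        summary.append(
            {
                "locus_id": locus_id,
                "cs_id": int(cs_id),
                "lead_snp": lead["SNPID"],
                "cs_size": len(members),
                # Number of variants at or above each PIP threshold.
                "pip_01": sum(r["PIP"] >= 0.1 for r in members),
                "pip_05": sum(r["PIP"] >= 0.5 for r in members),
                "pip_09": sum(r["PIP"] >= 0.9 for r in members),
            }
        )
    return summary


rows = [
    {"SNPID": "rs1", "CRED": 1, "PIP": 0.95},
    {"SNPID": "rs2", "CRED": 1, "PIP": 0.04},
    {"SNPID": "rs3", "CRED": 2, "PIP": 0.55},
    {"SNPID": "rs4", "CRED": 2, "PIP": 0.30},
]
summary = summarize_credible_sets(rows, "example_locus")
print(summary[0]["lead_snp"], summary[0]["cs_size"])  # rs1 2
```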

Source code in credtools/credibleset.py
def generate_cs_summary(
    causal_variants: pd.DataFrame,
    locus_id: str,
    locus_set: "LocusSet",
) -> List[Dict]:
    """
    Generate credible set summary from causal variants.

    Creates one summary row per credible set, including lead SNP, size,
    PIP thresholds, and purity metrics.

    Parameters
    ----------
    causal_variants : pd.DataFrame
        DataFrame of causal variants (rows where CRED != 0). Must have
        columns: CRED, PIP, SNPID.
    locus_id : str
        Locus identifier string.
    locus_set : LocusSet
        LocusSet object for LD-based purity calculation.

    Returns
    -------
    List[Dict]
        List of summary dictionaries, one per credible set.
    """
    if causal_variants.empty:
        return []

    cs_summary_list = []
    for cs_id in sorted(causal_variants["CRED"].unique()):
        cs_snps = causal_variants[causal_variants["CRED"] == cs_id]
        lead_snp_idx = cs_snps["PIP"].idxmax()
        lead_snp = cs_snps.loc[lead_snp_idx, "SNPID"]
        cs_size = len(cs_snps)
        pip_01 = int((cs_snps["PIP"] >= 0.1).sum())
        pip_05 = int((cs_snps["PIP"] >= 0.5).sum())
        pip_09 = int((cs_snps["PIP"] >= 0.9).sum())

        # Calculate purity if LD is available
        purity = None
        ld_list = [locus.ld for locus in locus_set.loci if locus.ld is not None]
        if len(ld_list) > 0:
            cs_snp_ids = cs_snps["SNPID"].tolist()
            purity = calculate_cs_purity(ld_list, cs_snp_ids)

        cs_summary_list.append(
            {
                "locus_id": locus_id,
                "cs_id": int(cs_id),
                "lead_snp": lead_snp,
                "cs_size": cs_size,
                "pip_01": pip_01,
                "pip_05": pip_05,
                "pip_09": pip_09,
                "purity": purity,
            }
        )

    return cs_summary_list

COJO Helpers

Wrapper for COJO.

conditional_selection(locus, p_cutoff=5e-08, collinear_cutoff=0.9, window_size=10000000, maf_cutoff=0.01, diff_freq_cutoff=0.2)

Perform conditional selection on the locus using COJO method.

Parameters:

Name Type Description Default
locus Locus

The locus to perform conditional selection on. Must contain summary statistics and LD matrix data.

required
p_cutoff float

The p-value cutoff for the conditional selection, by default 5e-8. If no SNPs pass this threshold and the cutoff is stricter than 1e-5, it is relaxed to 1e-5.

5e-08
collinear_cutoff float

The collinearity cutoff for the conditional selection, by default 0.9. SNPs with LD correlation above this threshold are considered collinear.

0.9
window_size int

The window size in base pairs for the conditional selection, by default 10000000. SNPs within this window are considered for conditional analysis.

10000000
maf_cutoff float

The minor allele frequency cutoff for the conditional selection, by default 0.01. SNPs with MAF below this threshold are excluded.

0.01
diff_freq_cutoff float

The difference in frequency cutoff between summary statistics and reference panel, by default 0.2. SNPs with frequency differences above this threshold are excluded.

0.2

Returns:

Type Description
DataFrame

The conditional selection results containing independently associated variants with columns including SNP identifiers, effect sizes, and conditional p-values.

Warnings

If no SNPs pass the initial p-value cutoff, the threshold is automatically relaxed to 1e-5 and a warning is logged.

If AF2 (reference allele frequency) is not available in the LD matrix, a warning is logged and frequency checking is disabled.
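The threshold relaxation described in the first warning can be sketched as a small standalone helper. `effective_p_cutoff` is a hypothetical name used only for illustration; in the library the check is done inline on the summary statistics table.

```python
from typing import Sequence

RELAXED_P_CUTOFF = 1e-5  # fallback value named in the warning above


def effective_p_cutoff(p_values: Sequence[float], p_cutoff: float = 5e-8) -> float:
    """Return the cutoff to use, relaxing it when no SNP passes (hypothetical helper)."""
    nothing_passes = not any(p < p_cutoff for p in p_values)
    # Only relax cutoffs that are stricter than the fallback itself.
    if p_cutoff < RELAXED_P_CUTOFF and nothing_passes:
        return RELAXED_P_CUTOFF
    return p_cutoff


print(effective_p_cutoff([1e-6, 0.2]))  # 1e-05
print(effective_p_cutoff([1e-9, 0.2]))  # 5e-08
```

A cutoff that is already looser than 1e-5 is returned unchanged even when no SNP passes it.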

Notes

COJO (Conditional and Joint analysis) performs stepwise conditional analysis to identify independently associated variants at a locus. The method:

  1. Identifies the most significant SNP
  2. Performs conditional analysis on remaining SNPs
  3. Iteratively adds independently associated SNPs
  4. Continues until no more SNPs meet significance criteria

The algorithm accounts for linkage disequilibrium patterns and helps distinguish truly independent signals from those in LD with lead variants.

Reference: Yang, J. et al. Conditional and joint multiple-SNP analysis of GWAS summary statistics identifies additional variants influencing complex traits. Nat Genet 44, 369-375 (2012).
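The four steps in the notes can be illustrated with a deliberately simplified forward selection on z-scores and an LD correlation matrix. This is a sketch under strong assumptions, not the COJO algorithm used by the library: it uses the common approximation z_j|S = (z_j - r_jS R_SS^-1 z_S) / sqrt(1 - r_jS R_SS^-1 r_Sj), and it ignores sample size, allele frequencies, the distance window, and the joint-model backward elimination. All function names are hypothetical.

```python
import math
from typing import List


def solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def conditional_z(z: List[float], ld: List[List[float]], j: int, selected: List[int]) -> float:
    """Approximate z-score of SNP j after conditioning on the selected SNPs."""
    if not selected:
        return z[j]
    r_ss = [[ld[s][t] for t in selected] for s in selected]
    r_js = [ld[j][s] for s in selected]
    w = solve(r_ss, r_js)  # weights R_SS^-1 r_Sj
    resid = z[j] - sum(wi * z[s] for wi, s in zip(w, selected))
    var = max(1.0 - sum(wi * ri for wi, ri in zip(w, r_js)), 1e-12)
    return resid / math.sqrt(var)


def stepwise_select(
    z: List[float],
    ld: List[List[float]],
    p_cutoff: float = 5e-8,
    collinear_cutoff: float = 0.9,
) -> List[int]:
    """Greedy forward selection of conditionally significant SNP indices."""
    selected: List[int] = []
    while True:
        best_j, best_p = None, 1.0
        for j in range(len(z)):
            if j in selected:
                continue
            # Skip SNPs that are collinear with an already selected SNP.
            if any(abs(ld[j][s]) > collinear_cutoff for s in selected):
                continue
            zc = conditional_z(z, ld, j, selected)
            p = math.erfc(abs(zc) / math.sqrt(2.0))  # two-sided normal p-value
            if p < best_p:
                best_j, best_p = j, p
        if best_j is None or best_p >= p_cutoff:
            return selected
        selected.append(best_j)


# SNP 1 is only significant because of its LD (r = 0.8) with SNP 0.
z_scores = [8.0, 6.4, 6.5, 1.0]
ld_matrix = [
    [1.0, 0.8, 0.0, 0.0],
    [0.8, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]
selected = stepwise_select(z_scores, ld_matrix)
print(selected)  # [0, 2]
```

In this toy input, SNP 1 has a marginal z-score of 6.4 but its conditional z-score given SNP 0 is zero, so only SNPs 0 and 2 are reported as independent signals.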

Examples:

>>> # Basic conditional selection
>>> results = conditional_selection(locus)
>>> print(f"Found {len(results)} independent signals")
Found 3 independent signals
>>> # With custom thresholds
>>> results = conditional_selection(
...     locus,
...     p_cutoff=1e-6,
...     maf_cutoff=0.05
... )
>>> print(results[['SNP', 'b', 'se', 'p']])
    SNP           b        se         p
0   rs123456   0.15   0.025   1.2e-08
1   rs789012  -0.08   0.020   4.5e-07
Source code in credtools/cojo.py
def conditional_selection(
    locus: Locus,
    p_cutoff: float = 5e-8,
    collinear_cutoff: float = 0.9,
    window_size: int = 10000000,
    maf_cutoff: float = 0.01,
    diff_freq_cutoff: float = 0.2,
) -> pd.DataFrame:
    """
    Perform conditional selection on the locus using COJO method.

    Parameters
    ----------
    locus : Locus
        The locus to perform conditional selection on. Must contain summary statistics
        and LD matrix data.
    p_cutoff : float, optional
        The p-value cutoff for the conditional selection, by default 5e-8.
        If no SNPs pass this threshold and the cutoff is stricter than 1e-5,
        it is relaxed to 1e-5.
    collinear_cutoff : float, optional
        The collinearity cutoff for the conditional selection, by default 0.9.
        SNPs with LD correlation above this threshold are considered collinear.
    window_size : int, optional
        The window size in base pairs for the conditional selection, by default 10000000.
        SNPs within this window are considered for conditional analysis.
    maf_cutoff : float, optional
        The minor allele frequency cutoff for the conditional selection, by default 0.01.
        SNPs with MAF below this threshold are excluded.
    diff_freq_cutoff : float, optional
        The difference in frequency cutoff between summary statistics and reference panel,
        by default 0.2. SNPs with frequency differences above this threshold are excluded.

    Returns
    -------
    pd.DataFrame
        The conditional selection results containing independently associated variants
        with columns including SNP identifiers, effect sizes, and conditional p-values.

    Warnings
    --------
    If no SNPs pass the initial p-value cutoff, the threshold is automatically
    relaxed to 1e-5 and a warning is logged.

    If AF2 (reference allele frequency) is not available in the LD matrix,
    a warning is logged and frequency checking is disabled.

    Notes
    -----
    COJO (Conditional and Joint analysis) performs stepwise conditional analysis
    to identify independently associated variants at a locus. The method:

    1. Identifies the most significant SNP
    2. Performs conditional analysis on remaining SNPs
    3. Iteratively adds independently associated SNPs
    4. Continues until no more SNPs meet significance criteria

    The algorithm accounts for linkage disequilibrium patterns and helps
    distinguish truly independent signals from those in LD with lead variants.

    Reference: Yang, J. et al. Conditional and joint multiple-SNP analysis of GWAS
    summary statistics identifies additional variants influencing complex traits.
    Nat Genet 44, 369-375 (2012).

    Examples
    --------
    >>> # Basic conditional selection
    >>> results = conditional_selection(locus)
    >>> print(f"Found {len(results)} independent signals")
    Found 3 independent signals

    >>> # With custom thresholds
    >>> results = conditional_selection(
    ...     locus,
    ...     p_cutoff=1e-6,
    ...     maf_cutoff=0.05
    ... )
    >>> print(results[['SNP', 'b', 'se', 'p']])
        SNP           b        se         p
    0   rs123456   0.15   0.025   1.2e-08
    1   rs789012  -0.08   0.020   4.5e-07
    """
    sumstats = locus.sumstats.copy()
    sumstats = sumstats[
        [
            ColName.SNPID,
            ColName.EA,
            ColName.NEA,
            ColName.BETA,
            ColName.SE,
            ColName.P,
            ColName.EAF,
        ]
    ]
    sumstats.columns = ["SNP", "A1", "A2", "b", "se", "p", "freq"]
    sumstats["N"] = locus.sample_size
    if p_cutoff < 1e-5 and len(sumstats[sumstats["p"] < p_cutoff]) == 0:
        logger.warning("No SNPs passed the p-value cutoff, using p_cutoff=1e-5")
        p_cutoff = 1e-5

    ld_matrix = locus.ld.r.copy()
    ld_freq: Optional[pd.DataFrame] = locus.ld.map.copy()
    if ld_freq is not None and "AF2" not in ld_freq.columns:
        logger.warning("AF2 is not in the LD matrix.")
        ld_freq = None
    elif ld_freq is not None:
        ld_freq = ld_freq[["SNPID", "AF2"]]
        ld_freq.columns = ["SNP", "freq"]
        ld_freq["freq"] = 1 - ld_freq["freq"]
    c = COJO(
        p_cutoff=p_cutoff,
        collinear_cutoff=collinear_cutoff,
        window_size=window_size,
        maf_cutoff=maf_cutoff,
        diff_freq_cutoff=diff_freq_cutoff,
    )
    c.load_sumstats(sumstats=sumstats, ld_matrix=ld_matrix, ld_freq=ld_freq)  # type: ignore
    cojo_result = c.conditional_selection()
    return cojo_result