Skip to content

unified_metric_calculations

calculate_heterozygosity(maf)

Calculate heterozygosity for each locus in the genotype matrix.

Source code in fpg_observational_model/unified_metric_calculations.py
def calculate_heterozygosity(maf):
    """Calculate heterozygosity for each locus.

    Args:
        maf: Either a 2-D genotype matrix (one row per genotype, one column
            per locus) or a scalar / 1-D sequence of per-locus allele
            frequencies. For a 2-D input the frequency of allele ``1`` is
            computed per column before the heterozygosity is derived.

    Returns:
        numpy array (0-d for a scalar input) holding
        ``1 - (p**2 + (1 - p)**2)`` for every locus, rounded to 4 decimals.
    """
    # Convert first so plain lists are accepted; the original read ``.ndim``
    # before converting and therefore failed on list input.
    maf = np.asarray(maf)
    if maf.ndim == 2:
        maf = np.sum(maf == 1, axis=0) / maf.shape[0]

    # Plain array arithmetic replaces np.vectorize, which is a Python-level
    # loop and raises on size-0 input.
    return np.round(1 - (maf**2 + (1 - maf)**2), 4)

calculate_individual_rh(barcode_heterozygosity, bootstrap_list)

Calculate individual R_h value based on infection heterozygosity and sampled H_Mono distribution.

Source code in fpg_observational_model/unified_metric_calculations.py
def calculate_individual_rh(barcode_heterozygosity, bootstrap_list):
    """Return one infection's R_h from its heterozygosity and the sampled H_Mono draws.

    Every draw ``h_mono`` contributes ``(h_mono - barcode_heterozygosity) / h_mono``;
    a draw equal to zero contributes 0 rather than dividing by zero. The value
    returned is the median of those contributions, rounded to 3 decimals.
    """
    per_draw_rh = [
        (h_mono - barcode_heterozygosity) / h_mono if h_mono != 0 else 0
        for h_mono in bootstrap_list
    ]
    return round(np.median(per_draw_rh), 3)

calculate_population_rh(df, monogenomic_dict, n_mono_boostraps=200)

Calculate the R_h statistic for the given sampling dataframe and IBS matrix.

Parameters:

Name Type Description Default
df

DataFrame containing infection data, needs to include effective COI information.

required
monogenomic_dict

Dictionary containing the IBS distribution data.

required
n_mono_boostraps

Number of bootstrap samples to draw for the population H_Mono estimation.

200

Logic from published inferential model (Wong et al 2022, https://doi.org/10.1093/pnasnexus/pgac187): - Identify the baseline co-transmission relatedness of unique monogenomic barcode pairs (H_Mono - 200 pairwise draw) - Identify per sample polygenomic heterozygosity by the number of Ns (H_Poly) - Calculate individual R_h = (H_Mono - H_Poly) / H_Mono - Calculate summary statistics (mean, median, std) from the individual polygenomic R_h values per site per year.

Adapting for the model: - (On hold - tested with too few infections to determine interpretation) For superinfection samples with an effective COI=2, H_Mono (measured) can be measured directly as the expectation of IBS values between the two unrelated genotypes in a mixed infection. - Replicate the bootstrap sampling of H_Mono by drawing from the IBS distribution of monogenomic samples, excluding IBS=1 values (i.e. identical barcodes), to calculate an H_Mono (inferred) value. - Polygenomic sample heterozygosity is calculated as the proportion of Ns in the barcode, assuming all alleles in an infection are detectable. Updates to make this more or less sensitive to minor alleles can be made in the generate_het_barcode function.

Source code in fpg_observational_model/unified_metric_calculations.py
def calculate_population_rh(df, monogenomic_dict, n_mono_boostraps=200):
    """
    Calculate the R_h statistic for the polygenomic infections in a sampling dataframe.

    Args:
        df: DataFrame containing infection data. Must include 'effective_coi' and
            'infIndex', plus either 'barcode_N_prop' or 'barcode_with_Ns'.
        monogenomic_dict: Dictionary containing the IBS distribution data; it is
            passed to sample_from_distribution to draw the H_Mono values.
        n_mono_boostraps: Number of bootstrap samples to draw for the population H_Mono estimation.

    Returns:
        Tuple ``(rh_measurements, individual_rh)``:
        - rh_measurements: one-row DataFrame with 'rh_poly_inferred_mean',
          'rh_poly_inferred_median' and 'rh_poly_inferred_std' (each rounded to 3 decimals).
        - individual_rh: DataFrame with 'infIndex' and 'individual_inferred_rh'
          for every infection with effective_coi > 1.

    Logic from published inferential model (Wong et al 2022, https://doi.org/10.1093/pnasnexus/pgac187):
    - Identify the baseline co-transmission relatedness of unique monogenomic barcode pairs (H_Mono - 200 pairwise draw)
    - Identify per sample polygenomic heterozygosity by the number of Ns (H_Poly)
    - Calculate individual R_h = (H_Mono - H_Poly) / H_Mono
    - Calculate summary statistics (mean, median, std) from the individual polygenomic R_h values per site per year.

    Adapting for the model:
    - (On hold - tested with too few infections to determine interpretation) For superinfection samples with an effective COI=2, H_Mono (measured) can be measured directly as the expectation of IBS values between the two unrelated genotypes in a mixed infection.
    - Replicate the bootstrap sampling of H_Mono by drawing from the IBS distribution of monogenomic samples, excluding IBS=1 values (i.e. identical barcodes), to calculate an H_Mono (inferred) value.
    - Polygenomic sample heterozygosity is calculated as the proportion of Ns in the barcode, assuming all alleles in an infection are detectable. Updates to make this more or less sensitive to minor alleles can be made in the generate_het_barcode function.
    """

    def n_fraction(barcode):
        # Non-string values (e.g. missing barcodes) and empty strings count as
        # having no heterozygous calls; the empty-string guard avoids a
        # division by zero.
        # NOTE(review): a barcode stored as a list of characters also yields 0
        # here, as it did before - confirm barcodes arrive as strings.
        if isinstance(barcode, str) and barcode:
            return barcode.count('N') / len(barcode)
        return 0

    # Measured H_Mono from COI=2 superinfections is on hold (see docstring);
    # only the inferred R_h is reported.
    poly_samples = df[df['effective_coi'] > 1].copy()

    # Columns are built element-wise as float arrays instead of with a
    # row-wise DataFrame.apply: on a frame with no polygenomic rows the
    # row-wise apply returns a DataFrame, which cannot be assigned to a column.
    if 'barcode_N_prop' not in poly_samples.columns:
        poly_samples['barcode_N_prop'] = np.array(
            [n_fraction(barcode) for barcode in poly_samples['barcode_with_Ns']],
            dtype=float,
        )

    bootstrap_list = sample_from_distribution(monogenomic_dict, n_bootstraps=n_mono_boostraps)

    poly_samples['individual_inferred_rh'] = np.array(
        [calculate_individual_rh(n_prop, bootstrap_list) for n_prop in poly_samples['barcode_N_prop']],
        dtype=float,
    )

    individual_rh = poly_samples['individual_inferred_rh']
    rh_measurements = pd.DataFrame([{
        'rh_poly_inferred_mean': round(individual_rh.mean(), 3),
        'rh_poly_inferred_median': round(individual_rh.median(), 3),
        'rh_poly_inferred_std': round(individual_rh.std(), 3)
    }])

    return rh_measurements, poly_samples[['infIndex', 'individual_inferred_rh']]

comprehensive_group_summary(group)

Calculate comprehensive summary statistics for infection data.

Returns mean, median, std, min, max for continuous variables.

Source code in fpg_observational_model/unified_metric_calculations.py
def comprehensive_group_summary(group):
    """
    Build one Series of summary statistics for a group of infections.

    An empty group returns the result of _empty_comprehensive_summary().
    Otherwise the Series holds the infection count followed by the outputs of
    the COI, genome-ID and cotransmission helpers, plus the 'genotype_coi'
    statistics when that column is present.
    """
    if not len(group):
        return _empty_comprehensive_summary()

    # Genome IDs are analysed for all infections and for monogenomic ones only;
    # cotransmission is analysed among infections whose true COI exceeds 1.
    monogenomic_nids = group[group['effective_coi'] == 1]['recursive_nid']
    polygenomic_cotx = group[group['true_coi'] > 1]['cotx']

    summary = pd.Series({
        'n_infections': len(group),
        **_comprehensive_stats(group['true_coi'], 'true_coi'),
        **_comprehensive_stats(group['effective_coi'], 'effective_coi'),
        **_analyze_genome_ids(group['recursive_nid'], "all_genomes"),
        **_analyze_genome_ids(monogenomic_nids, "mono_genomes"),
        **_analyze_binary_in_subset(polygenomic_cotx, 'cotransmission')
    })

    if 'genotype_coi' not in group.columns:
        return summary

    return pd.concat([summary, _comprehensive_stats(group['genotype_coi'], 'genotype_coi')])

generate_het_barcode(matrix, indices)

Checks for unique alleles at each locus for a specified set of genotypes identified by indices. If all alleles are the same at a locus, returns '0' or '1' for that locus. If there is a mix of alleles at a locus, returns 'N' for that locus. If no indices are provided, returns an empty result: a genotype count of 0, an empty barcode and an empty heterozygosity list.

Note: To update for multi-allelic loci, modify the conditions within the list comprehension.

TODO: Add option to account for densities to potentially mask polygenomic samples due to low density.

Source code in fpg_observational_model/unified_metric_calculations.py
def generate_het_barcode(matrix, indices):
    """Summarise the genotypes selected by ``indices`` as a barcode with heterozygous calls.

    ``indices`` may be a sequence of row indices or its string representation
    (parsed with ast.literal_eval). For each locus the barcode holds '0' or '1'
    when every selected genotype carries that allele, and 'N' otherwise.

    Returns a tuple ``(n_unique_genotypes, barcode, heterozygosity)`` where
    ``barcode`` is a list of single-character strings and ``heterozygosity``
    is the per-locus output of calculate_heterozygosity as a list.
    When no indices are given, or any error occurs while building the result
    (the error is printed), the tuple ``(0, [], [])`` is returned.

    Note: To update for multi-allelic loci, modify the per-locus allele conditions.

    TODO: Add option to account for densities to potentially mask polygenomic samples due to low density.
    """
    if isinstance(indices, str):
        indices = ast.literal_eval(indices)

    if len(indices) == 0:
        return 0, [], []

    try:
        selected = matrix[indices, :]
        n_unique_genotypes = np.unique(selected, axis=0).shape[0]

        # Column-wise checks: fixed for allele 0, fixed for allele 1, else mixed.
        fixed_zero = np.all(selected == 0, axis=0)
        fixed_one = np.all(selected == 1, axis=0)
        barcode = np.where(fixed_zero, '0', np.where(fixed_one, '1', 'N')).tolist()

        locus_het = calculate_heterozygosity(selected)

        return n_unique_genotypes, barcode, locus_het.tolist()

    except Exception as e:
        print(f"Error in generate_het_barcode: {e}")
        return 0, [], []

get_matrix(name)

Get a registered matrix

Source code in fpg_observational_model/unified_metric_calculations.py
def get_matrix(name):
    """Return the matrix stored in the module-level registry under ``name``.

    Raises:
        KeyError: if nothing is registered under ``name``; the message lists
            the names that are available.
    """
    if name not in _matrix_registry:
        known_names = list(_matrix_registry.keys())
        raise KeyError(f"Matrix '{name}' not found. Available: {known_names}")
    return _matrix_registry[name]

ibx_distribution(indices, hash_ibx)

Fxn returns counts of pairwise values per group. Key: IBx value, value: counts

Source code in fpg_observational_model/unified_metric_calculations.py
def ibx_distribution(indices, hash_ibx):
    """
    Count the pairwise IBx values among ``indices``.

    Every unordered pair drawn from ``indices`` is looked up in ``hash_ibx``,
    scaled by the matrix maximum and rounded to 2 decimals; a pair made of the
    same index twice counts as 1.
    Returns a dict with key: IBx value, value: number of pairs with that value.
    """
    scale = np.max(hash_ibx)
    tally = {}
    # Repeated pairs are collapsed first so each distinct pair is looked up once.
    for pair, n_pairs in Counter(combinations(indices, 2)).items():
        if len(set(pair)) == 1:
            value = 1
        else:
            first, second = pair
            value = hash_ibx[first, second] / scale
        value = round(value, 2)
        tally[value] = tally.get(value, 0) + n_pairs
    return tally

identify_nested_comparisons(df, sampling_column_name, config=None)

Generate a list of infections within sampling schemes for looping through nested comparisons.

Source code in fpg_observational_model/unified_metric_calculations.py
def identify_nested_comparisons(df, sampling_column_name, config=None):
    """
    Generate a list of infections within sampling schemes for looping through nested comparisons.

    Args:
        df: DataFrame of infections. Uses 'infIndex' plus whichever grouping
            columns the chosen options need ('group_year', 'group_month',
            the sampling column, 'population', 'effective_coi',
            'fever_status', 'age_day'). The frame is not modified.
        sampling_column_name: Name of the sampling column. A name containing
            'seasonal' groups by that column, one containing 'month' groups by
            'group_month', anything else by 'group_year'. A name containing
            'age' additionally groups by ('group_year', sampling column) and
            disables the config-driven subgroups.
        config: Optional mapping with boolean flags 'populations',
            'polygenomic', 'symptomatic' and 'age_bins' selecting subgroups.

    Returns:
        Dict mapping a comparison type ('group_year', 'group_month',
        'season_bins', 'age_bins', 'populations', 'polygenomic',
        'symptomatic') to a dict of group key -> list of 'infIndex' values.
        Subgroup keys are (time value, subgroup value) tuples. A comparison
        with only one group available is skipped with a printed message.
    """
    def infections_by(frame, keys, **groupby_kwargs):
        # Map every group key to the list of infection ids in that group.
        return frame.groupby(keys, **groupby_kwargs)['infIndex'].apply(list).to_dict()

    nested_indices = {}

    # Specifying time groups
    if 'seasonal' in sampling_column_name:
        time_group = sampling_column_name
        if len(df[sampling_column_name].unique()) > 1:
            nested_indices['season_bins'] = infections_by(df, sampling_column_name)
        else:
            print("User specified comparisons by season, but only one season found.")
    else:
        time_group = 'group_month' if 'month' in sampling_column_name else 'group_year'
        nested_indices[time_group] = infections_by(df, time_group)

    if 'age' in sampling_column_name:
        if len(df[sampling_column_name].unique()) > 1:
            nested_indices['age_bins'] = infections_by(df, ['group_year', sampling_column_name])
        else:
            print("User specified nested comparisons by age bin, but only one age bin available in the sample subset.")

    # Specifying non-time groups; i.e. subgroups
    if 'age' in sampling_column_name or config is None:
        return nested_indices

    if config.get('populations', False):
        if len(df['population'].unique()) > 1:
            nested_indices['populations'] = infections_by(df, [time_group, 'population'])
        else:
            print("User specified nested comparisons by population, but only one population is available.")

    if config.get('polygenomic', False):
        # assign() returns a new frame, so the caller's DataFrame is untouched.
        flagged = df.assign(is_polygenomic=df['effective_coi'] > 1)
        has_poly = bool(flagged['is_polygenomic'].any())
        has_mono = bool((~flagged['is_polygenomic']).any())
        if has_poly and has_mono:
            nested_indices['polygenomic'] = infections_by(flagged, [time_group, 'is_polygenomic'])
        else:
            print("User specified nested comparisons by monogenomic or polygenomic infections, but only one group available.")

    if config.get('symptomatic', False):
        if len(df['fever_status'].unique()) > 1:
            nested_indices['symptomatic'] = infections_by(df, [time_group, 'fever_status'])
        else:
            print("User specified nested comparisons by fever status, but only one fever status is available.")

    if config.get('age_bins', False):
        days_per_year = 365.25
        # The oldest bin is open-ended. Using max(age_day) + 1 as the last edge
        # made the edges non-increasing (and pd.cut raise) whenever nobody in
        # the sample was older than 15 years.
        age_bins = [0, int(days_per_year * 5), int(days_per_year * 15), float('inf')]
        age_bin_labels = ['0-5yrs', '5-15yrs', '15+yrs']
        # Bin on a new frame instead of writing 'age_bin' into the caller's df.
        binned = df.assign(
            age_bin=pd.cut(df['age_day'], bins=age_bins, labels=age_bin_labels, include_lowest=True)
        )

        observed_age_bins = binned['age_bin'].unique()
        if len(observed_age_bins) > 1:
            nested_indices['age_bins'] = infections_by(binned, [time_group, 'age_bin'], observed=True)
        else:
            available_age_group = observed_age_bins[0] if len(observed_age_bins) else None
            print(f"User specified nested comparisons by age bins, but only one age group, {available_age_group}, is available.")

    return nested_indices

inf_ibx_summary(ibx_matrix, ibx_indices)

Run the IBx summary for a list of genome indices for each polygenomic infection.

Source code in fpg_observational_model/unified_metric_calculations.py
def inf_ibx_summary(ibx_matrix, ibx_indices):
    """
    Summarise the pairwise IBx values among one infection's genome indices.

    Builds the pairwise distribution with ibx_distribution, passes it to
    weighted_describe_scipy with the "ibx" label and returns the first row of
    that result as a dict.
    """
    pairwise_counts = ibx_distribution(ibx_indices, ibx_matrix)
    return weighted_describe_scipy(pairwise_counts, "ibx").iloc[0].to_dict()

process_nested_fws(nested_indices, sampling_df, ibs_matrix='ibs_matrix')

Calculate Fws for a single sample based on population heterozygosity.
Mirrors logic from R package moimix used to calculate F_ws. Specifically following the logic in the function getFws() with this as the comment:

Compute the within host diversity statistic according to the method devised in Manske et al., 2012. Briefly, within sample heterozygosity and within population heterozygosity are computed and assigned to ten equal sized MAF bins [0, 0.05]...[0.45, 0.5]. For each bin the mean within sample and population heterozygosity is computed. A regression line of these values through the origin is computed for each sample. The $F_{ws}$ is then $1 - \beta$.

Manske, Magnus, et al. "Analysis of Plasmodium falciparum diversity in natural infections by deep sequencing." Nature 487.7407 (2012): 375-379.
Source code in fpg_observational_model/unified_metric_calculations.py
def _flatten_genome_indices(nid_column):
    """Concatenate the list-valued entries of ``nid_column``; non-list entries are skipped."""
    flat = []
    for idx_list in nid_column:
        if isinstance(idx_list, list):
            flat.extend(idx_list)
    return flat


def _binned_population_het(matrix_name, genome_indices):
    """Return (allele frequencies, per-locus heterozygosity, frequency bin per locus, mean heterozygosity per bin)."""
    matrix = get_matrix(matrix_name)[genome_indices, :]
    allele_freq = np.sum(matrix == 1, axis=0) / matrix.shape[0]
    locus_het = calculate_heterozygosity(matrix)
    # NOTE(review): the bins are built on the frequency of allele 1, so loci
    # whose frequency is 0 or above 0.5 appear to fall outside every bin and
    # drop out of the regression - confirm whether folding to the minor allele
    # frequency was intended.
    freq_bins = pd.cut(allele_freq, bins=np.linspace(0, 0.5, 11), labels=False) + 1
    het_by_bin = pd.Series(locus_het).groupby(freq_bins).mean()
    return allele_freq, locus_het, freq_bins, het_by_bin


def _fws_from_binned_het(sample_het_list, freq_bins, pop_het_by_bin):
    """Return 1 - slope of the through-origin regression of sample on population heterozygosity per bin (NaN on failure)."""
    sample_het = np.array(sample_het_list)
    try:
        sample_het_by_bin = pd.Series(sample_het).groupby(freq_bins).mean()
        combined = pd.DataFrame({
            'pop_het': pop_het_by_bin,
            'sample_het': sample_het_by_bin
        }).dropna()

        if len(combined) == 0:
            return np.nan

        X = combined['pop_het'].values.reshape(-1, 1)
        y = combined['sample_het'].values
        model = LinearRegression(fit_intercept=False)
        model.fit(X, y)
        return round(1 - model.coef_[0], 3)
    except Exception:
        # Best effort: a sample that cannot be regressed gets no Fws value.
        return np.nan


def _fws_summary_row(fws_values, comparison_type, time_level, time_label, subgroup, allele_freq, locus_het):
    """Build one summary row from the non-missing Fws values, or return None when there are none."""
    valid_fws = fws_values.dropna()
    if len(valid_fws) == 0:
        return None
    fws_summary = _comprehensive_stats(valid_fws, 'fws')
    return {
        'comparison_type': comparison_type,
        'year_group': str(time_label) if 'group_year' in time_level else None,
        'month_group': str(time_label) if 'group_month' in time_level else None,
        'subgroup': subgroup,
        'allele_frequencies': np.round(allele_freq, 3).tolist(),
        'heterozygosity_per_position': np.round(locus_het, 3).tolist(),
        'n_samples': len(valid_fws),
        **fws_summary.to_dict()
    }


def process_nested_fws(nested_indices, sampling_df, ibs_matrix = 'ibs_matrix'):
    """Calculate Fws summaries for each time group and nested subgroup.

    Mirrors logic from R package moimix used to calculate F_ws. Specifically following the logic in the function getFws() with this as the comment:

    Compute the within host diversity statistic according to the method devised in Manske et al., 2012. Briefly, within sample heterozygosity and within population heterozygosity are computed and assigned to ten equal sized MAF bins [0, 0.05]...[0.45, 0.5]. For each bin the mean within sample and population heterozygosity is computed. A regression line of these values through the origin is computed for each sample. The Fws is then 1 - beta, where beta is the slope of that line.

    Manske, Magnus, et al. "Analysis of Plasmodium falciparum diversity in natural infections by deep sequencing." Nature 487.7407 (2012): 375-379.

    Args:
        nested_indices: Dict of comparison type -> {group key: list of infIndex},
            as produced by identify_nested_comparisons.
        sampling_df: DataFrame with 'infIndex', 'original_nid' (lists of genome
            row indices) and 'heterozygosity' (per-locus values per infection).
        ibs_matrix: Name of the registered matrix fetched with get_matrix.

    Returns:
        DataFrame with one row per time group and per nested subgroup that has
        at least one valid Fws value; an empty DataFrame when no time-level
        grouping exists or nothing could be summarised.
    """
    fws_stats_list = []
    time_level_types = ('group_year', 'group_month', 'season_bins')

    # Pick the time-level grouping. This must be a single if/elif chain: the
    # previous separate `if` for 'group_year' fell through to the final `else`
    # and returned an empty frame whenever only 'group_year' was present.
    if 'group_month' in nested_indices:
        time_level = 'group_month'
    elif 'season_bins' in nested_indices:
        time_level = 'season_bins'
    elif 'group_year' in nested_indices:
        time_level = 'group_year'
    else:
        print("Warning: No year-level indices found")
        return pd.DataFrame()

    # Process each time group ("year" below also covers months and seasons)
    for year_key, year_indices in nested_indices[time_level].items():
        year = str(year_key) if not isinstance(year_key, tuple) else str(year_key[0])
        # Copy so the 'fws' column is never written onto a slice of sampling_df.
        year_subset = sampling_df[sampling_df['infIndex'].isin(year_indices)].copy()

        genome_indices = _flatten_genome_indices(year_subset['original_nid'])
        if len(genome_indices) == 0:
            continue

        # Population-level heterozygosity is calculated once per time group
        group_af, group_het, maf_bins, group_het_by_bin = _binned_population_het(ibs_matrix, genome_indices)

        for comparison_type, group_data in nested_indices.items():
            if comparison_type in time_level_types:
                # Time-level calculation over every infection in the group
                year_subset['fws'] = year_subset['heterozygosity'].apply(
                    lambda het: _fws_from_binned_het(het, maf_bins, group_het_by_bin)
                )
                row = _fws_summary_row(
                    year_subset['fws'], comparison_type, comparison_type,
                    year, None, group_af, group_het
                )
                if row is not None:
                    fws_stats_list.append(row)
                continue

            # Nested subgroups keyed by (time value, subgroup value)
            for key, nested_data in group_data.items():
                if not isinstance(key, tuple):
                    continue
                key_year, subgroup = key
                if str(key_year) != year:
                    continue
                if not (isinstance(nested_data, list) and len(nested_data) > 0):
                    continue

                subset_df = year_subset[year_subset['infIndex'].isin(nested_data)].copy()
                if subset_df.empty:
                    continue

                subgroup_genome_indices = _flatten_genome_indices(subset_df['original_nid'])
                if len(subgroup_genome_indices) == 0:
                    continue

                # Recalculate the population profile for this subgroup only
                subgroup_af, subgroup_het, subgroup_maf_bins, subgroup_het_by_bin = _binned_population_het(
                    ibs_matrix, subgroup_genome_indices
                )
                subset_df['fws'] = subset_df['heterozygosity'].apply(
                    lambda het: _fws_from_binned_het(het, subgroup_maf_bins, subgroup_het_by_bin)
                )
                # Label the row with the time level actually in use; keying the
                # label on the subgroup comparison type left it always empty.
                row = _fws_summary_row(
                    subset_df['fws'], comparison_type, time_level,
                    year, str(subgroup), subgroup_af, subgroup_het
                )
                if row is not None:
                    fws_stats_list.append(row)

    if fws_stats_list:
        return pd.DataFrame(fws_stats_list)
    else:
        print("Warning: No Fws summary generated")
        return pd.DataFrame()

process_nested_ibx(df, gt_matrix, nested_indices, ibx_prefix, individual_ibx_calculation=True, save_ibx_distributions=True, save_pairwise_ibx=False)

Calculate IBx for nested comparison groups.

Parameters:

Name Type Description Default
df

DataFrame linking infection information to genotype indices

required
gt_matrix

Matrix of genotypes, roots or alleles.

required
nested_indices

Dictionary with comparison types as keys and nested group data as values

required
save_ibx_distributions

Option to save the dictionary of value with counts for the full pairwise distribution

True
individual_ibx_calculation

Whether to calculate individual IBx values

True

Returns:

Type Description

DataFrame with summary statistics for each group/subgroup

Source code in fpg_observational_model/unified_metric_calculations.py
def process_nested_ibx(df, gt_matrix, nested_indices,
                       ibx_prefix,
                       individual_ibx_calculation=True,
                       save_ibx_distributions=True,
                       save_pairwise_ibx=False):
    """
    Calculate IBx for nested comparison groups.

    Args:
        df: DataFrame linking infection information to genotype indices
        gt_matrix: Name of the registered matrix of genotypes, roots or alleles (fetched with get_matrix).
        nested_indices: Dictionary with comparison types as keys and nested group data as values
        ibx_prefix: Label passed to weighted_describe_scipy for every summary
            (presumably the prefix of the output column names - TODO confirm).
        individual_ibx_calculation: Whether to calculate individual IBx values
        save_ibx_distributions: Option to save the dictionary of value with counts for the full pairwise distribution
        save_pairwise_ibx: Whether to write the per-time-group IBx index table
            (CSV) and matrix (.npy) into the "output" directory.

    Returns:
        Tuple ``(ibx_results_df, individual_ibx_df, ibx_dist_dict)``:
        - ibx_results_df: DataFrame with summary statistics for each group/subgroup
        - individual_ibx_df: DataFrame of per-infection summaries for infections with effective_coi > 1
        - ibx_dist_dict: comparison type -> {group key: pairwise distribution}
        All three are empty when ``nested_indices`` has no time-level grouping.
    """
    import os

    time_level_types = ('group_year', 'group_month', 'season_bins')

    # Later entries take precedence when more than one time-level key exists.
    all_year_indices = None
    for time_key in time_level_types:
        if time_key in nested_indices:
            all_year_indices = nested_indices[time_key]

    if all_year_indices is None:
        # Without this guard the loop below failed on an unbound variable.
        print("Warning: No year-level indices found")
        return pd.DataFrame(), pd.DataFrame(), {}

    ibx_dist_dict, individual_ibx_dict = {}, {}
    ibx_summ_list = []
    for year_key, inf_indices in all_year_indices.items():
        # Handle both string keys and tuple keys for year
        year = str(year_key) if not isinstance(year_key, tuple) else str(year_key[0])

        year_subset = df[df['infIndex'].isin(inf_indices)]
        year_subset = update_ibx_index(year_subset)

        # Step 1: Run pairwise IBx calculations once per year
        genome_indices = []
        for idx_list in year_subset['recursive_nid']:
            if isinstance(idx_list, list):
                genome_indices.extend(idx_list)

        matrix = get_matrix(gt_matrix)[genome_indices, :]
        print("Genotype matrix shape:", matrix.shape)
        ibx_indices, ibx_matrix = calculate_ibx_matrix(year_subset, matrix)

        # FOR HAIRBALL connectedness plots, save the ibx_matrix and ibx_indices per year
        if save_pairwise_ibx:
            # Update output directory here
            output_dir = "output"
            os.makedirs(output_dir, exist_ok=True)
            # pandas has no module-level save_csv; write through DataFrame.to_csv.
            # Wrapping in pd.DataFrame also covers a dict-of-columns input.
            pd.DataFrame(ibx_indices).to_csv(f"{output_dir}/ibx_indices_{year_key}.csv", index=False)
            np.save(f"{output_dir}/ibx_matrix_{year_key}.npy", ibx_matrix)

        # Add column with the ibx_index for each infection
        ibx_mapping = dict(zip(ibx_indices['ibx_nid'], ibx_indices['ibx_index']))
        year_subset['ibx_index'] = year_subset['ibx_nid'].apply(
            lambda nid_list: [ibx_mapping[nid] for nid in nid_list] if isinstance(nid_list, list) else None
        )

        # Step 2: Run IBx summaries for nested groups within each year
        for comparison_group, group_data in nested_indices.items():
            if comparison_group in time_level_types:
                # Time-level calculation over every genome in the group.
                # 'group_month' is handled here too; it was previously routed
                # to the subgroup branch, where its non-tuple keys were skipped.
                pooled_indices = list(chain.from_iterable(year_subset['ibx_index'].tolist()))

                if len(pooled_indices) > 1:
                    distribution = ibx_distribution(pooled_indices, ibx_matrix)
                    summary_stats = weighted_describe_scipy(distribution, ibx_prefix)

                    # Add metadata columns
                    result_row = summary_stats.iloc[0].to_dict()
                    result_row['comparison_type'] = comparison_group
                    result_row['year_group'] = year
                    result_row['subgroup'] = None
                    ibx_summ_list.append(result_row)

                    if save_ibx_distributions:
                        ibx_dist_dict.setdefault(comparison_group, {})[year] = distribution
                continue

            # Handle nested groups keyed by (time value, subgroup value)
            for key, nested_data in group_data.items():
                if not isinstance(key, tuple):
                    continue
                key_year, subgroup = key
                if str(key_year) != year:
                    continue
                if not (isinstance(nested_data, list) and len(nested_data) > 1):
                    continue

                subset_df = year_subset[year_subset['infIndex'].isin(nested_data)]
                if subset_df.empty:
                    continue

                subgroup_ibx_indices = list(chain.from_iterable(subset_df['ibx_index'].tolist()))
                if len(subgroup_ibx_indices) <= 1:
                    continue

                distribution = ibx_distribution(subgroup_ibx_indices, ibx_matrix)
                summary_stats = weighted_describe_scipy(distribution, ibx_prefix)

                result_row = summary_stats.iloc[0].to_dict()
                result_row['comparison_type'] = comparison_group
                result_row['year_group'] = year
                result_row['subgroup'] = str(subgroup)
                ibx_summ_list.append(result_row)

                if save_ibx_distributions:
                    ibx_dist_dict.setdefault(comparison_group, {})[key] = distribution

        # Step 3: Individual IBx calculations for polygenomic infections.
        # Runs once per time group; it used to sit inside the comparison loop
        # and recompute identical values for every comparison type.
        if individual_ibx_calculation:
            polygenomic_subset = year_subset[year_subset['effective_coi'] > 1]
            if not polygenomic_subset.empty:
                polygenomic_dict = dict(zip(polygenomic_subset['infIndex'], polygenomic_subset['ibx_index']))

                for inf_id, ibx_list in polygenomic_dict.items():
                    distribution = ibx_distribution(ibx_list, ibx_matrix)
                    individual_ibx_dict[inf_id] = weighted_describe_scipy(distribution, ibx_prefix)

    if ibx_summ_list:
        ibx_results_df = pd.DataFrame(ibx_summ_list)
    else:
        ibx_results_df = pd.DataFrame()

    if individual_ibx_dict:
        individual_ibx_df = pd.concat(individual_ibx_dict, names=['infIndex', 'row_id']).reset_index(level=0)
    else:
        individual_ibx_df = pd.DataFrame()

    return ibx_results_df, individual_ibx_df, ibx_dist_dict

register_matrix(name, matrix)

Register a matrix for IBx calculations

Source code in fpg_observational_model/unified_metric_calculations.py
def register_matrix(name, matrix):
    """Store ``matrix`` in the module-level matrix registry under ``name``.

    A confirmation line is printed after the registry entry is written.
    """
    _matrix_registry[name] = matrix
    confirmation = f"Registered matrix: {name}"
    print(confirmation)

sample_from_distribution(dist_dict, n_bootstraps=200, exclude_keys=[1])

Unpacks the pairwise IBS distribution dictionary to sample from the distribution n times. Excluded keys are values that do not represent the distribution of interest - e.g. IBS=1 for identical barcodes since these would not be detected as a mixed infection.

Source code in fpg_observational_model/unified_metric_calculations.py
def sample_from_distribution(dist_dict, n_bootstraps = 200, exclude_keys=[1]):
    """Draw weighted bootstrap samples from a pairwise IBS distribution.

    Parameters
    ----------
    dist_dict : dict
        Maps each observed value to its weight (e.g. a pair count).
    n_bootstraps : int
        Number of draws, passed as ``size`` to ``np.random.choice``.
    exclude_keys : list or None
        Keys of ``dist_dict`` to drop before sampling. Excluded keys are
        values that do not represent the distribution of interest - e.g.
        IBS=1 for identical barcodes, since these would not be detected as a
        mixed infection. ``None`` means nothing is excluded. The default list
        is only read, never mutated.

    Returns
    -------
    numpy.ndarray
        ``n_bootstraps`` values drawn with replacement, each key being
        selected with probability proportional to its weight.

    Raises
    ------
    ValueError
        If no keys remain after exclusion, or the remaining weights sum to
        zero. Both situations previously surfaced as an opaque ``ValueError``
        from inside ``np.random.choice``.
    """
    if exclude_keys is None:
        exclude_keys = []
    filtered_dict = {k: v for k, v in dist_dict.items() if k not in exclude_keys}

    if not filtered_dict:
        raise ValueError(
            "Cannot sample from distribution: no values remain after applying "
            f"exclude_keys={list(exclude_keys)!r}."
        )

    values = list(filtered_dict.keys())
    weights = list(filtered_dict.values())

    total_weight = sum(weights)
    if total_weight == 0:
        # Normalising by zero would hand NaN probabilities to numpy.
        raise ValueError(
            "Cannot sample from distribution: the remaining weights sum to zero."
        )

    probabilities = np.array(weights) / total_weight
    distribution_list = np.random.choice(values, size=n_bootstraps, p=probabilities)

    return distribution_list

update_ibx_index(filter_df)

For year-specific IBx calculations, map each recursive_nid value to its position in the sorted set of unique values across all rows, and store the mapped lists in a new ibx_nid column.

Source code in fpg_observational_model/unified_metric_calculations.py
def update_ibx_index(filter_df):
    """ 
    For year specific IBX calculations, update the recursive_nid to a global order based on their unique values.
    """
    # Step 1: Get all unique recursive_nid values across all rows
    filter_df = filter_df.copy()
    all_nids = []
    for nid_list in filter_df['recursive_nid']:
        all_nids.extend(nid_list)

    # Get unique values and sort them
    unique_nids = sorted(set(all_nids))

    # Step 2: Create a mapping from nid to its global order
    nid_to_order = {nid: i for i, nid in enumerate(unique_nids)}

    # Step 3: Apply the mapping to create the order column
    def map_to_global_order(nid_list):
        return [nid_to_order[nid] for nid in nid_list]

    filter_df['ibx_nid'] = filter_df['recursive_nid'].apply(map_to_global_order)

    return filter_df

weighted_describe_scipy(summary_dict, ibx_prefix)

Calculate stats by expanding the weighted dictionary

Source code in fpg_observational_model/unified_metric_calculations.py
def weighted_describe_scipy(summary_dict, ibx_prefix):
    """Calculate describe()-style stats by expanding a weighted dictionary.

    Parameters
    ----------
    summary_dict : dict
        Maps each observed value to its count. Counts are truncated with
        ``int()``; entries whose truncated count is not positive contribute
        no observations.
    ibx_prefix : str
        Prefix for the output column names.

    Returns
    -------
    pandas.DataFrame
        A single row with the columns ``<prefix>_count``, ``_mean``, ``_std``,
        ``_min``, ``_25%``, ``_50%``, ``_75%`` and ``_max`` (statistics rounded
        to 3 decimals). An empty DataFrame is returned when ``summary_dict`` is
        empty or expands to zero observations.
    """
    if not summary_dict:
        return pd.DataFrame()

    # Keep only the entries that actually contribute observations.
    weighted = [(value, int(count)) for value, count in summary_dict.items()]
    weighted = [(value, count) for value, count in weighted if count > 0]
    if not weighted:
        # Nothing to describe; mirror the empty-input result instead of
        # letting the min/max reductions fail on an empty array.
        return pd.DataFrame()

    # Expand to one array element per observation.
    values, counts = zip(*weighted)
    expanded_values = np.repeat(np.array(values), counts)

    count = int(expanded_values.size)
    mean = np.mean(expanded_values)
    # Sample std (ddof=1) to match pandas; it is undefined for one observation,
    # so report NaN directly rather than triggering numpy's runtime warning.
    std = np.std(expanded_values, ddof=1) if count > 1 else np.nan

    summary_data = {
        f'{ibx_prefix}_count': count,
        f'{ibx_prefix}_mean': round(mean, 3),
        f'{ibx_prefix}_std': round(std, 3),
        f'{ibx_prefix}_min': round(np.min(expanded_values), 3),
        f'{ibx_prefix}_25%': round(np.percentile(expanded_values, 25), 3),
        f'{ibx_prefix}_50%': round(np.median(expanded_values), 3),
        f'{ibx_prefix}_75%': round(np.percentile(expanded_values, 75), 3),
        f'{ibx_prefix}_max': round(np.max(expanded_values), 3)
    }

    return pd.DataFrame([summary_data])