import os
import time
from cafaeval.graph import Graph, Prediction, GroundTruth, propagate
import numpy as np
import logging
logger = logging.getLogger(__name__)
_parser_logger = logging.getLogger("cafaeval.parser")
# Library convention: give both loggers a NullHandler so that nothing is
# emitted (and no "no handlers found" notice appears) unless the host
# application configures logging itself.
for _module_logger in (logger, _parser_logger):
    _module_logger.addHandler(logging.NullHandler())
del _module_logger
# import xml.etree.ElementTree as ET
[docs]
def _store_obo_term(term_dict, term_id, namespace, name, term_def, alt_id, rel):
    # Record one parsed term under its namespace in ``term_dict``.
    term_dict.setdefault(namespace, {})[term_id] = {'name': name,
                                                    'namespace': namespace,
                                                    'def': term_def,
                                                    'alt_id': alt_id,
                                                    'rel': rel}


def obo_parser(obo_file, valid_rel=("is_a", "part_of"), ia_file=None, orphans=True):
    """
    Parse an OBO file and return one ontology graph per namespace.

    Obsolete terms are excluded, as are terms whose namespace is ``external``
    and terms that carry no ``namespace`` tag at all.

    :param obo_file: path to the OBO file.
    :param valid_rel: relationship types kept as edges; only ``"is_a"`` and
        ``"part_of"`` (via ``relationship: part_of ...``) are recognised.
    :param ia_file: optional information-accretion file, parsed with
        :func:`ia_parser` and passed to every ``Graph``.
    :param orphans: forwarded unchanged to ``Graph``.
    :return: dict mapping namespace -> ``Graph``.
    """
    term_dict = {}
    term_id = None
    namespace = None
    name = None
    term_def = None
    alt_id = []
    rel = []
    # Nothing is stored until the first "id:" line flips this to False.
    obsolete = True
    with open(obo_file) as f:
        for line in f:
            # Split on the first ": " only; tag values may contain ": " too.
            k, sep, v = line.strip().partition(": ")
            if not sep:
                continue  # stanza headers ("[Term]"), blank lines, untagged text
            if k == "id":
                # A new stanza starts: flush the previous term if it is valid.
                if term_id is not None and obsolete is False and namespace is not None:
                    _store_obo_term(term_dict, term_id, namespace, name, term_def, alt_id, rel)
                term_id = v
                # Reset every per-term field so nothing leaks from the previous
                # term (e.g. a term without "def:" must not inherit one).
                name = None
                term_def = None
                alt_id = []
                rel = []
                obsolete = False
                namespace = None
            elif k == "alt_id":
                alt_id.append(v)
            elif k == "name":
                name = v
            elif k == "namespace" and v != 'external':
                namespace = v
            elif k == "def":
                term_def = v
            elif k == 'is_obsolete':
                obsolete = True
            elif k == "is_a" and k in valid_rel:
                # "is_a: GO:0000001 ! label" -> keep the id before the "!".
                rel.append(v.split('!')[0].strip())
            elif k == "relationship" and v.startswith("part_of") and "part_of" in valid_rel:
                # "relationship: part_of GO:0000001 ! label" -> second token.
                rel.append(v.split()[1].strip())
    # Flush the final term: no following "id:" line triggers it.
    if term_id is not None and obsolete is False and namespace is not None:
        _store_obo_term(term_dict, term_id, namespace, name, term_def, alt_id, rel)

    # Parse IA file
    ia_dict = None
    if ia_file is not None:
        ia_dict = ia_parser(ia_file)

    ontologies = {}
    for ns, ont_dict in term_dict.items():
        ontologies[ns] = Graph(ns, ont_dict, ia_dict, orphans)
    return ontologies
[docs]
def update_toi(ontologies, toi_file):
    """
    Restrict each ontology's terms of interest (TOI) to the terms listed in a file.

    Use this to remove terms from evaluation, e.g. terms obsoleted since the
    ontology was created. Alternative IDs in the file are mapped to their
    canonical term(s). The new TOI is intersected with the current one, so
    terms already excluded (e.g. roots) stay excluded. ``toi_ia`` is
    restricted the same way when it is set.

    :param ontologies: dict returned from :func:`obo_parser` (namespace -> Graph)
    :param toi_file: file whose first whitespace-separated field on each line
        is a GO ID to keep; blank lines and unknown IDs are ignored
    :return: the same ``ontologies`` dict, modified in place; ``toi`` and
        ``toi_ia`` become sorted int64 index arrays (possibly empty)
    """
    new_toi = {ns: set() for ns in ontologies}
    with open(toi_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            term = fields[0]
            for ns, ont in ontologies.items():
                if term in ont.terms_dict:
                    new_toi[ns].add(ont.terms_dict[term]['index'])
                elif term in ont.terms_dict_alt:
                    # Alt ID: keep every canonical term it stands for.
                    for canonical in ont.terms_dict_alt[term]:
                        new_toi[ns].add(ont.terms_dict[canonical]['index'])
    for ns, ont in ontologies.items():
        # Explicit integer dtype: np.array([]) is float64 and cannot be used
        # as an index array when the intersection happens to be empty.
        # Sorting makes the column order deterministic.
        ont.toi = np.array(sorted(new_toi[ns].intersection(ont.toi)), dtype=np.int64)
        # toi_ia holds the non-zero IA terms; drop non-TOI terms there too.
        if ont.toi_ia is not None:
            ont.toi_ia = np.array(sorted(new_toi[ns].intersection(ont.toi_ia)),
                                  dtype=np.int64)
    return ontologies
[docs]
def gt_parser(gt_file, ontologies):
    """
    Parse a ground-truth file into one propagated ``GroundTruth`` per namespace.

    Each line is ``protein_id term_id``, whitespace separated. Extra columns
    are ignored and lines with fewer than two fields are skipped. Terms absent
    from every ontology are discarded, and alternative IDs are replaced by
    their canonical term(s). Matrices are boolean with shape
    ``(n_proteins, ont.idxs)`` and are propagated with ``mode='max'``.

    :param gt_file: path to the ground-truth annotation file.
    :param ontologies: dict namespace -> Graph, as returned by :func:`obo_parser`.
    :return: dict namespace -> GroundTruth; namespaces without annotations
        are omitted.
    """
    gt_dict = {}
    replaced = {}
    with open(gt_file) as f:
        for line in f:
            fields = line.split()
            if len(fields) < 2:
                continue
            p_id, term_id = fields[:2]
            for ns, ont in ontologies.items():
                if term_id in ont.terms_dict:
                    gt_dict.setdefault(ns, {}).setdefault(p_id, []).append(term_id)
                    break
                # Replace alternative ids with canonical ids
                if term_id in ont.terms_dict_alt:
                    for t_id in ont.terms_dict_alt[term_id]:
                        gt_dict.setdefault(ns, {}).setdefault(p_id, []).append(t_id)
                    replaced[ns] = replaced.get(ns, 0) + 1
                    break

    gts = {}
    for ns, ont in ontologies.items():
        prot_terms = gt_dict.get(ns)
        if not prot_terms:
            continue
        terms_dict = ont.terms_dict
        ids = {}
        # Collect the non-zero coordinates up front so propagate() can skip
        # the dense ``np.nonzero`` scan (the bool GT matrix has a density
        # around 1e-4 on real corpora, so scanning every cell dominates).
        nnz = sum(len(term_ids) for term_ids in prot_terms.values())
        nz_rows = np.empty(nnz, dtype=np.int64)
        nz_cols = np.empty(nnz, dtype=np.int64)
        k = 0
        for i, (p_id, term_ids) in enumerate(prot_terms.items()):
            ids[p_id] = i
            for term_id in term_ids:
                nz_rows[k] = i
                nz_cols[k] = terms_dict[term_id]['index']
                k += 1
        matrix = np.zeros((len(prot_terms), ont.idxs), dtype='bool')
        # One vectorised scatter instead of a Python-level write per annotation.
        matrix[nz_rows, nz_cols] = True
        nz_scores = np.ones(nnz, dtype=matrix.dtype)
        # Lazy %-style args: the matrix is only formatted when DEBUG is enabled.
        logger.debug("gt matrix %s %s", ns, matrix)
        propagate(matrix, ont, ont.order, mode='max',
                  _triples=(nz_rows, nz_cols, nz_scores))
        logger.debug("gt matrix propagated %s %s", ns, matrix)
        gts[ns] = GroundTruth(ids, matrix, ns)
        if logger.isEnabledFor(logging.DEBUG):
            # Guarded so the full count_nonzero scan only runs when logged.
            logger.debug('Ground truth: %s, proteins %s, annotations %s, replaced alt. ids %s',
                         ns, len(ids), np.count_nonzero(matrix), replaced.get(ns, 0))
    return gts
[docs]
def gt_exclude_parser(exclude_file, gt, ontologies):
    """
    Build exclusion matrices aligned row-for-row with the ground truth.

    ``exclude_file`` has the same format as a ground-truth file and is parsed
    and propagated with :func:`gt_parser`. For each namespace in ``gt``, the
    result has the same shape and protein index as ``gt[ns]``. Rows for
    proteins without exclusions are all zero, and so are whole namespaces
    that do not appear in the exclusion file.

    :return: dict namespace -> GroundTruth holding the exclusion matrix.
    """
    # Propagate exclude terms and parse alternative IDs
    exclude_gt = gt_parser(exclude_file, ontologies)
    exclude = {}
    for ns, ns_gt in gt.items():
        exclude_matrix = np.zeros_like(ns_gt.matrix)
        ns_exclude = exclude_gt.get(ns)
        # gt_parser omits namespaces with no exclusion lines: treat them as
        # "exclude nothing" instead of raising KeyError.
        if ns_exclude is not None:
            for protein, gt_index in ns_gt.ids.items():
                # Copy the row only for proteins that are also in the GT.
                if protein in ns_exclude.ids:
                    exclude_matrix[gt_index, :] = ns_exclude.matrix[ns_exclude.ids[protein], :]
        exclude[ns] = GroundTruth(ns_gt.ids, exclude_matrix, ns)
    return exclude
def _pred_parser_legacy(pred_file, ontologies, gts, ns_dict, term_index, ids,
matrix, row_nnz, replaced, max_terms):
"""Original line-by-line parser.
Retained as the fallback for ``max_terms``-capped parses (the cap is
order-sensitive so the vectorised path cannot reproduce it) and for
any pathological input the fast path refuses to handle.
"""
with open(pred_file, buffering=1024 * 1024) as f:
for line in f:
line = line.strip().split()
if line and len(line) > 2:
p_id, term_id, prob = line[:3]
ns = ns_dict.get(term_id)
if ns in gts and p_id in gts[ns].ids:
i = gts[ns].ids[p_id]
term_ids = [term_id]
if term_id in ontologies[ns].terms_dict_alt:
term_ids = ontologies[ns].terms_dict_alt[term_id]
replaced.setdefault(ns, 0)
replaced[ns] += len(term_ids)
for term_id in term_ids:
j = term_index[ns].get(term_id)
old = matrix[ns][i, j]
if max_terms is not None and old == 0.0 and row_nnz[ns][i] > max_terms:
continue
prob_f = float(prob)
if prob_f > old:
ids[ns][p_id] = i
matrix[ns][i, j] = prob_f
if old == 0.0:
row_nnz[ns][i] += 1
def _pred_parser_vectorised(pred_file, ontologies, gts, ns_dict, term_index,
                            ids, matrix, replaced):
    """PyArrow-backed bulk parser for the common case (no ``max_terms`` cap).

    Strategy:

    1. ``pyarrow.csv.read_csv`` reads the whole file at native speed
       (~10x pandas C engine on CAFA-shaped predictions).
    2. ``pid`` / ``tid`` string columns are dictionary-encoded once, so
       Python-level dict lookups run over the (small) unique-value sets
       instead of over every one of the millions of rows.
    3. The resulting numpy int code arrays are filtered per namespace
       with vectorised comparisons; per-namespace reductions use the same
       sort + ``np.maximum.reduceat`` group-max + scatter as the sparse
       propagation kernel.

    Only strictly positive group maxima are written. This mirrors the legacy
    "if prob > old" rule on a zero-initialised matrix: zero or negative scores
    never reach the matrix and never register a protein in ``ids``.

    Fills ``matrix``, ``ids`` and ``replaced`` in place and returns None. It
    expects every ``matrix[ns]`` to be all zeros on entry.

    Raises if pyarrow is missing or on any format surprise (non-tsv, wrong
    column count, missing values), so the caller can fall back to
    :func:`_pred_parser_legacy`.
    """
    import pyarrow as pa  # lazy: pyarrow is an optional speed dependency
    import pyarrow.csv as pc

    tbl = pc.read_csv(
        pred_file,
        read_options=pc.ReadOptions(column_names=["pid", "tid", "prob"]),
        parse_options=pc.ParseOptions(delimiter="\t"),
        convert_options=pc.ConvertOptions(column_types={
            "pid": pa.string(),
            "tid": pa.string(),
            "prob": pa.float64(),
        }),
    )
    if tbl.num_rows == 0:
        return

    pids_dict = tbl.column("pid").combine_chunks().dictionary_encode()
    tids_dict = tbl.column("tid").combine_chunks().dictionary_encode()
    # zero_copy_only=True raises if the column contains nulls -> fallback.
    probs_arr = tbl.column("prob").combine_chunks().to_numpy(zero_copy_only=True)
    pid_unique = pids_dict.dictionary.to_pylist()
    tid_unique = tids_dict.dictionary.to_pylist()
    pid_codes = pids_dict.indices.to_numpy(zero_copy_only=False)
    tid_codes = tids_dict.indices.to_numpy(zero_copy_only=False)

    ns_list = list(gts.keys())
    ns_to_idx = {ns: i for i, ns in enumerate(ns_list)}

    # Per unique tid: which namespace does it belong to (-1 if none).
    tid_ns_code = np.full(len(tid_unique), -1, dtype=np.int8)
    for code, tid in enumerate(tid_unique):
        ns = ns_dict.get(tid)
        if ns is not None and ns in ns_to_idx:
            tid_ns_code[code] = ns_to_idx[ns]

    # Per unique tid, per namespace: canonical column index; -2 if the tid is
    # an alt id in that namespace (expanded below); -1 otherwise.
    tid_col_per_ns = []
    for ns in ns_list:
        tmap = term_index[ns]
        alt_dict = ontologies[ns].terms_dict_alt
        col_arr = np.full(len(tid_unique), -1, dtype=np.int64)
        for code, tid in enumerate(tid_unique):
            canon_col = tmap.get(tid)
            if canon_col is not None:
                col_arr[code] = canon_col
            elif alt_dict and tid in alt_dict:
                col_arr[code] = -2
        tid_col_per_ns.append(col_arr)

    # Per unique pid, per namespace: row index (-1 if not in this GT set).
    pid_row_per_ns = []
    for ns in ns_list:
        gt_ids = gts[ns].ids
        row_arr = np.full(len(pid_unique), -1, dtype=np.int64)
        for code, pid in enumerate(pid_unique):
            v = gt_ids.get(pid)
            if v is not None:
                row_arr[code] = v
        pid_row_per_ns.append(row_arr)

    row_ns = tid_ns_code[tid_codes]  # int8, same length as the full table
    for ns_idx, ns in enumerate(ns_list):
        ns_mask = row_ns == ns_idx
        if not ns_mask.any():
            continue
        pid_codes_ns = pid_codes[ns_mask]
        tid_codes_ns = tid_codes[ns_mask]
        probs_ns = probs_arr[ns_mask]

        # Protein filter: only keep rows whose pid is in this GT.
        p_idx_all = pid_row_per_ns[ns_idx][pid_codes_ns]
        in_gt = p_idx_all >= 0
        if not in_gt.any():
            continue
        p_idx_all = p_idx_all[in_gt]
        tid_codes_ns = tid_codes_ns[in_gt]
        probs_ns = probs_ns[in_gt]

        # Column resolution: canonical terms, alt ids (value -2), misses (-1).
        col_all = tid_col_per_ns[ns_idx][tid_codes_ns]
        canonical_mask = col_all >= 0
        alt_mask = col_all == -2
        p_idx_parts = [p_idx_all[canonical_mask]]
        t_idx_parts = [col_all[canonical_mask]]
        v_parts = [probs_ns[canonical_mask]]

        if alt_mask.any():
            # Alt-id expansion: each alt id may map to several canonical
            # term ids. Loop over the (small) alt subset only.
            alt_dict = ontologies[ns].terms_dict_alt
            ns_term_index = term_index[ns]
            exp_p = []
            exp_t = []
            exp_v = []
            for k in np.flatnonzero(alt_mask):
                canon_set = alt_dict.get(tid_unique[tid_codes_ns[k]])
                if not canon_set:
                    continue
                replaced[ns] = replaced.get(ns, 0) + len(canon_set)
                for canon in canon_set:
                    col = ns_term_index.get(canon)
                    if col is None:
                        continue
                    exp_p.append(int(p_idx_all[k]))
                    exp_t.append(col)
                    exp_v.append(float(probs_ns[k]))
            if exp_p:
                p_idx_parts.append(np.asarray(exp_p, dtype=np.int64))
                t_idx_parts.append(np.asarray(exp_t, dtype=np.int64))
                v_parts.append(np.asarray(exp_v, dtype=np.float64))

        p_final = np.concatenate(p_idx_parts) if len(p_idx_parts) > 1 else p_idx_parts[0]
        t_final = np.concatenate(t_idx_parts) if len(t_idx_parts) > 1 else t_idx_parts[0]
        v_final = np.concatenate(v_parts) if len(v_parts) > 1 else v_parts[0]
        if p_final.size == 0:
            continue

        # Sort-based group-max over (row, col). Single int64 flat key, one
        # stable argsort, np.maximum.reduceat finishes the reduction.
        n_terms = np.int64(matrix[ns].shape[1])
        flat = p_final * n_terms + t_final
        order = np.argsort(flat, kind="stable")
        flat_s = flat[order]
        v_s = v_final[order]
        starts = np.empty(flat_s.size, dtype=bool)
        starts[0] = True
        np.not_equal(flat_s[1:], flat_s[:-1], out=starts[1:])
        start_idx = np.flatnonzero(starts)
        max_v = np.maximum.reduceat(v_s, start_idx)
        unique_flat = flat_s[start_idx]

        # The legacy rule "if prob > old: matrix[i, j] = prob" with old == 0
        # only ever stores strictly positive values. Drop everything else so
        # negative scores cannot leak into the matrix on the fast path.
        positive = max_v > 0.0
        if not positive.any():
            continue
        max_v = max_v[positive]
        unique_flat = unique_flat[positive]
        unique_rows = unique_flat // n_terms
        unique_cols = unique_flat % n_terms

        # The matrix starts at zero, so direct assignment of the positive
        # group maxima is equivalent to the legacy write rule.
        matrix[ns][unique_rows, unique_cols] = max_v

        # Register every protein that received a value (ascending row order).
        inv_ids = {int(v): k for k, v in gts[ns].ids.items()}
        for r in np.unique(unique_rows).tolist():
            ids[ns][inv_ids[r]] = r
[docs]
def pred_parser(pred_file, ontologies, gts, prop_mode, max_terms=None, n_cpu=0):
    """
    Parse a prediction file and return one ``Prediction`` per namespace.

    Each line is ``protein_id term_id score``. Only proteins present in the
    ground truth ``gts`` are kept, and alternative IDs are mapped to their
    canonical terms. If a term is predicted multiple times for the same
    target, the maximum score is stored. The score matrices are then
    propagated with ``prop_mode``.

    The PyArrow fast path is tried unless ``max_terms`` is set or the
    ``CAFAEVAL_FAST_PARSER`` environment variable is "0", "false" or "False".
    If it fails for any reason, the line-by-line legacy parser is used.
    This is the slow step if the input file is huge, ca. 1 minute for 5GB
    input on SSD disk.

    :param max_terms: optional per-protein cap on stored terms (legacy parser only).
    :param n_cpu: forwarded to ``propagate`` as ``parallel``.
    :return: dict namespace -> Prediction; empty if nothing overlaps the GT.
    """
    ids = {}
    matrix = {}
    ns_dict = {}  # {term_id (canonical or alt): namespace}
    replaced = {}
    row_nnz = {}
    term_index = {}
    for ns in gts:
        matrix[ns] = np.zeros(gts[ns].matrix.shape, dtype='float')
        row_nnz[ns] = np.zeros(gts[ns].matrix.shape[0], dtype=np.int32)
        ids[ns] = {}
        term_index[ns] = {t: info["index"] for t, info in ontologies[ns].terms_dict.items()}
        for term in ontologies[ns].terms_dict:
            ns_dict[term] = ns
        for term in ontologies[ns].terms_dict_alt:
            ns_dict[term] = ns

    t0 = time.time()
    fast_path = (
        max_terms is None
        and os.environ.get("CAFAEVAL_FAST_PARSER", "1") not in ("0", "false", "False")
    )
    used_fast_path = False
    if fast_path:
        try:
            _pred_parser_vectorised(
                pred_file, ontologies, gts, ns_dict, term_index,
                ids, matrix, replaced,
            )
            used_fast_path = True
        except Exception as exc:
            # pyarrow is optional, so its absence is expected and only gets a
            # debug record; any other failure is worth a warning.
            level = logging.DEBUG if isinstance(exc, ImportError) else logging.WARNING
            _parser_logger.log(
                level,
                "pred_parser fast path failed, falling back to legacy loop",
                extra={"file": pred_file, "error": repr(exc)},
            )
            # Reset any partial state the fast path may have written.
            for ns in gts:
                matrix[ns].fill(0)
                ids[ns].clear()
            replaced.clear()
    if not used_fast_path:
        _pred_parser_legacy(
            pred_file, ontologies, gts, ns_dict, term_index,
            ids, matrix, row_nnz, replaced, max_terms,
        )
    t1 = time.time()
    path_label = "vectorised" if used_fast_path else "legacy"
    _parser_logger.info(
        f"pred_parser read: {os.path.basename(pred_file)} ({path_label}) "
        f"in {t1 - t0:.2f}s",
        extra={"file": pred_file, "seconds": round(t1 - t0, 3),
               "path": path_label},
    )

    predictions = {}
    tp0 = time.time()
    for ns, ns_ids in ids.items():
        if not ns_ids:
            continue
        ns_matrix = matrix[ns]
        # Lazy %-style args: the matrix is only formatted when DEBUG is enabled.
        logger.debug("pred matrix %s %s", ns, ns_matrix)
        t_ns = time.time()
        propagate(ns_matrix, ontologies[ns], ontologies[ns].order, mode=prop_mode, parallel=n_cpu)
        # Measure once so the message and the structured fields agree.
        elapsed = time.time() - t_ns
        n_annots = int(np.count_nonzero(ns_matrix))
        _parser_logger.info(
            f"pred_parser propagated {ns:>18s}: "
            f"{len(ns_ids)} proteins, {n_annots} annots "
            f"({elapsed:.2f}s)",
            extra={"file": pred_file, "ns": ns,
                   "proteins": int(len(ns_ids)),
                   "annotations": n_annots,
                   "seconds": round(elapsed, 3)},
        )
        logger.debug("pred matrix %s %s", ns, ns_matrix)
        predictions[ns] = Prediction(ns_ids, ns_matrix, ns)

    # Read time plus propagate time, reported identically in text and extras.
    total = (t1 - t0) + (time.time() - tp0)
    _parser_logger.info(
        f"pred_parser total: read+propagate in {total:.2f}s",
        extra={"file": pred_file, "seconds": round(total, 3)},
    )
    if not predictions:
        logger.warning("Empty prediction! Check format or overlap with ground truth")
    return predictions
[docs]
def ia_parser(file):
    """
    Parse an information-accretion (IA) file into ``{term_id: ia}``.

    Each non-blank line must hold exactly two whitespace-separated fields:
    a term ID and its IA value, which is parsed as float. Blank or
    whitespace-only lines are skipped. A later line for the same term
    overwrites an earlier one.

    :raises ValueError: on a line with a field count other than two, or a
        non-numeric IA value.
    """
    ia_dict = {}
    with open(file) as f:
        for line in f:
            fields = line.split()
            # Lines read from a file always contain "\n", so test the split
            # result rather than the raw line to skip blank lines.
            if not fields:
                continue
            term, ia = fields
            ia_dict[term] = float(ia)
    return ia_dict