Shortcuts

Source code for ignite.metrics.rec_sys.map

from collections.abc import Callable

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

# Public API of this module: only the ``MAP`` metric class is exported by ``import *``.
__all__ = ["MAP"]


class MAP(Metric):
    r"""Calculates the Mean Average Precision (MAP) at `k` for Recommendation Systems.

    MAP measures the mean of Average Precision (AP) across all users. AP for a single user is the average
    of precision values computed at every position where a relevant item appears in the ranked top-k list,
    divided by the total number of relevant items for that user (clipped at ``k``).

    .. math::
        \text{AP}@K_i = \frac{1}{\min(R_i, K)} \sum_{j=1}^{K} \text{Precision}@j \cdot \mathbb{1}(\text{rel}_{i,j})

    .. math::
        \text{MAP}@K = \frac{1}{N} \sum_{i=1}^{N} \text{AP}@K_i

    where :math:`R_i` is the number of relevant items for user :math:`i`, :math:`\text{rel}_{i,j}` is 1 if the
    item at rank :math:`j` is relevant and 0 otherwise, and :math:`\text{Precision}@j` is the proportion of
    relevant items in the top :math:`j` ranked predictions.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - ``y_pred`` is expected to be raw logits or probability scores for each item in the catalog.
    - ``y`` is expected to be binary (only 0s and 1s) values where ``1`` indicates a relevant item.
      Any strictly positive value is treated as relevant.
    - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`; other shapes raise ``ValueError``.
    - If a ``k`` is larger than ``num_items``, the whole catalog is ranked for that ``k``.
    - returns a list of MAP values ordered by the sorted values of ``top_k``.

    Args:
        top_k: a single positive integer or a list of positive integers that specifies ``k`` for calculating
            MAP@top-k. If a single int is provided, it will be wrapped in a list. Default is 10.
        ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros) are
            ignored in computation of MAP. If set False, such users are counted with an Average Precision
            of 0. By default, True.
        output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s
            ``process_function``'s output into the form expected by the metric. The output is expected to be
            a tuple ``(prediction, target)`` where ``prediction`` and ``target`` are tensors of shape
            ``(batch, num_items)``.
        device: specifies which device updates are accumulated on. Setting the metric's device to be the
            same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU.
        skip_unrolling: specifies whether input should be unrolled or not before being processed. Should be
            true for multi-output models.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of ``(y_pred, y)``. If not,
        ``output_transform`` can be added to the metric to transform the output into the form expected by
        the metric.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit
        :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MAP(top_k=[1, 2, 3, 4])
            metric.attach(default_evaluator, "map")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0],
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 1.0],
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["map"])

        .. testoutput::

            [0.5, 0.625, 0.625, 0.75]

    .. versionadded:: 0.6.0
    """

    required_output_keys = ("y_pred", "y")
    _state_dict_all_req_keys = ("_sum_ap_per_k", "_num_examples")

    def __init__(
        self,
        top_k: list[int] | int = 10,
        ignore_zero_hits: bool = True,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        if not isinstance(top_k, (int, list)):
            raise ValueError("top_k must be either int or a list[int]")

        top_k = [top_k] if isinstance(top_k, int) else top_k

        if len(top_k) == 0:
            raise ValueError("top_k must have at least one positive value")

        if any(k <= 0 for k in top_k):
            raise ValueError("top_k must be list of positive integers only.")

        # Sorted so that the largest k is last; ``update`` ranks once up to that k and reuses prefixes.
        self.top_k = sorted(top_k)
        self.ignore_zero_hits = ignore_zero_hits
        super().__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear accumulated state: one running AP sum per ``k`` plus the number of users seen."""
        self._sum_ap_per_k = torch.zeros(len(self.top_k), device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Accumulate the per-``k`` Average Precision sums for one batch.

        Args:
            output: a pair ``(y_pred, y)`` of tensors with identical shape ``(batch, num_items)``.

        Raises:
            ValueError: if ``output`` is not a pair, if the shapes differ, or if the tensors are not 2-D.
        """
        if len(output) != 2:
            raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.")

        y_pred, y = output
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")
        if y_pred.ndim != 2:
            # Non-2-D input would sum AP over every leading row while counting only ``shape[0]`` users.
            raise ValueError(
                f"y_pred and y must have shape (batch, num_items), got {y_pred.ndim}-dimensional tensors."
            )

        # Single definition of relevance (strictly positive target) shared by the zero-hit filter,
        # the ranked hits and the AP denominator.
        relevance = (y > 0).to(torch.float32)

        if self.ignore_zero_hits:
            valid_mask = torch.any(relevance > 0, dim=-1)
            y_pred = y_pred[valid_mask]
            relevance = relevance[valid_mask]

        num_users = relevance.shape[0]
        if num_users == 0:
            return

        # torch.topk rejects k > num_items; a larger k simply ranks the whole catalog.
        max_k = min(self.top_k[-1], y_pred.shape[-1])
        _, indices = torch.topk(y_pred, k=max_k, dim=-1)
        ranked_relevance = torch.gather(relevance, dim=-1, index=indices)

        # Total number of relevant items per user across the catalog (AP denominator, clipped at k below).
        total_relevant = relevance.sum(dim=-1)

        # Precision@j for every rank j in [1, max_k]. Prefixes of the cumulative sum do not depend on k,
        # so this is computed once and sliced for each requested k.
        positions = torch.arange(1, max_k + 1, dtype=torch.float32, device=ranked_relevance.device)
        precision_at_j = torch.cumsum(ranked_relevance, dim=-1) / positions
        # Only ranks that hold a relevant item contribute to AP.
        hit_precision = precision_at_j * ranked_relevance

        for i, k in enumerate(self.top_k):
            sum_precision = hit_precision[:, :k].sum(dim=-1)
            denom = torch.clamp(total_relevant, max=float(k))
            # Users without relevant items (only present when ignore_zero_hits=False) get AP = 0.
            ap_k = torch.where(
                denom > 0,
                sum_precision / denom,
                torch.zeros_like(sum_precision),
            )
            self._sum_ap_per_k[i] += ap_k.sum().to(self._device)

        self._num_examples += num_users

    @sync_all_reduce("_sum_ap_per_k", "_num_examples")
    def compute(self) -> list[float]:
        """Return MAP@k for each ``k`` in ascending order of ``top_k``.

        Raises:
            NotComputableError: if no user has been accumulated since the last ``reset``.
        """
        if self._num_examples == 0:
            raise NotComputableError("MAP must have at least one example.")
        return (self._sum_ap_per_k / self._num_examples).tolist()

© Copyright 2026, PyTorch-Ignite Contributors. Last updated on 04/25/2026, 11:55:08 PM.

Built with Sphinx using a theme provided by Read the Docs.
×

Search Docs