Shortcuts

Source code for ignite.metrics.rec_sys.mrr

from collections.abc import Callable

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["MRR"]


class MRR(Metric):
    r"""Calculates the Mean Reciprocal Rank (MRR) at `k` for Recommendation Systems.

    MRR measures the average reciprocal rank of the first relevant item in the predicted
    ranking for each user. The reciprocal rank for a user is :math:`1/\text{rank}` where
    :math:`\text{rank}` is the position of the first relevant item in the ranked list
    (1-indexed). Users for which no relevant item appears in the top-k results contribute
    0 to the score.

    .. math::
        \text{MRR}@K = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{\text{rank}_i}

    where :math:`\text{rank}_i \in \{1, 2, \ldots, K\}` is the rank of the first relevant
    item for user :math:`i` in the top-k predictions, and is treated as :math:`\infty`
    (yielding a reciprocal rank of 0) if no relevant item is in the top-k predictions.

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - ``y_pred`` is expected to be raw logits or probability scores for each item in the catalog.
    - ``y`` is expected to be binary (only 0s and 1s) values where ``1`` indicates a relevant item.
      Any strictly positive label is treated as relevant.
    - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`.
    - returns a list of MRR values ordered by the sorted values of ``top_k``.

    Args:
        top_k: a single positive integer or a list of positive integers that specifies ``k``
            for calculating MRR@top-k. If a single int is provided, it will be wrapped in a list.
            Values of ``k`` larger than the number of items are evaluated over the full ranking.
            Default is 10.
        ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros)
            are ignored in computation of MRR. If set False, such users are counted with a
            reciprocal rank of 0. By default, True.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.
            The output is expected to be a tuple ``(prediction, target)``
            where ``prediction`` and ``target`` are tensors
            of shape ``(batch, num_items)``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether input should be unrolled or not before being processed.
            Should be true for multi-output models.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MRR(top_k=[1, 2, 3, 4])
            metric.attach(default_evaluator, "mrr")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0],
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 1.0],
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["mrr"])

        .. testoutput::

            [0.5, 0.75, 0.75, 0.75]

    .. versionadded:: 0.6.0
    """

    required_output_keys = ("y_pred", "y")
    _state_dict_all_req_keys = ("_sum_rr_per_k", "_num_examples")

    def __init__(
        self,
        top_k: list[int] | int = 10,
        ignore_zero_hits: bool = True,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        if not isinstance(top_k, (int, list)):
            raise ValueError("top_k must be either int or a list[int]")

        top_k = [top_k] if isinstance(top_k, int) else top_k

        if not top_k:
            raise ValueError("top_k must have at least one positive value")

        # Reject non-integer entries here instead of failing later inside ``torch.topk``.
        if any(not isinstance(k, int) or k <= 0 for k in top_k):
            raise ValueError("top_k must be list of positive integers only.")

        # Kept sorted so that ``self.top_k[-1]`` is the largest cutoff and the
        # results of ``compute`` follow ascending ``k``.
        self.top_k = sorted(top_k)
        self.ignore_zero_hits = ignore_zero_hits
        super().__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the accumulated reciprocal-rank sums (one per ``k``) and the example count."""
        self._sum_rr_per_k = torch.zeros(len(self.top_k), device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Accumulate reciprocal ranks for one batch.

        Args:
            output: a ``(y_pred, y)`` pair of tensors with identical shapes.

        Raises:
            ValueError: if ``output`` does not hold exactly two tensors or their shapes differ.
        """
        if len(output) != 2:
            raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.")

        y_pred, y = output
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")

        if self.ignore_zero_hits:
            # Drop users that have no relevant item at all.
            valid_mask = torch.any(y > 0, dim=-1)
            y_pred = y_pred[valid_mask]
            y = y[valid_mask]

        if y.shape[0] == 0:
            return

        # ``torch.topk`` fails when k exceeds the ranked dimension, so cap the
        # largest cutoff at the number of available items. MRR@k for such k is
        # identical to MRR over the full ranking.
        max_k = min(self.top_k[-1], y_pred.shape[-1])

        # Top-max_k item indices ordered by descending score.
        _, indices = torch.topk(y_pred, k=max_k, dim=-1)

        # Relevance in ranked order, binarised once so that the hit test and the
        # first-hit position agree even for non-binary positive labels.
        ranked_hits = torch.gather(y, dim=-1, index=indices) > 0

        has_hit = ranked_hits.any(dim=-1)
        # argmax returns the index of the first maximal value, i.e. the first
        # relevant position; rows without any hit yield position 1 but are
        # zeroed out below through ``has_hit``.
        first_pos = torch.argmax(ranked_hits.to(torch.long), dim=-1) + 1
        reciprocal_rank = has_hit.to(torch.float32) / first_pos.to(torch.float32)

        for i, k in enumerate(self.top_k):
            # A user scores at cutoff k only if the first relevant item sits within the top-k.
            within_k = (first_pos <= k).to(reciprocal_rank.dtype)
            self._sum_rr_per_k[i] += (reciprocal_rank * within_k).sum().to(self._device)

        self._num_examples += y.shape[0]

    @sync_all_reduce("_sum_rr_per_k", "_num_examples")
    def compute(self) -> list[float]:
        """Return MRR@k for every configured ``k``, in ascending order of ``k``.

        Raises:
            NotComputableError: if no example has been accumulated.
        """
        if self._num_examples == 0:
            raise NotComputableError("MRR must have at least one example.")

        return (self._sum_rr_per_k / self._num_examples).tolist()

© Copyright 2026, PyTorch-Ignite Contributors. Last updated on 04/25/2026, 11:55:08 PM.

Built with Sphinx using a theme provided by Read the Docs.
×

Search Docs