#!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/10/30 15:04 # @Author : Scheaven # @File : nn_matching.py # @description: import numpy as np def _pdist(a, b): a, b = np.asarray(a), np.asarray(b) if len(a) == 0 or len(b) == 0: return np.zeros((len(a), len(b))) a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] r2 = np.clip(r2, 0., float(np.inf)) return r2 def _cosine_distance(a, b, data_is_normalized=False): if not data_is_normalized: a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) return 1. - np.dot(a, b.T) def _nn_euclidean_distance(x, y): distances = _pdist(x, y) return np.maximum(0.0, distances.min(axis=0)) def _nn_cosine_distance(x, y): distances = _cosine_distance(x, y) return distances.min(axis=0) class NearestNeighborDistanceMetric(object): def __init__(self, metric, matching_threshold, budget=None): if metric == "euclidean": self._metric = _nn_euclidean_distance elif metric == "cosine": self._metric = _nn_cosine_distance else: raise ValueError( "Invalid metric; must be either 'euclidean' or 'cosine'") self.matching_threshold = matching_threshold self.budget = budget self.samples = {} # 可以加载离线库 def partial_fit(self, features, targets, active_targets): for feature, target in zip(features, targets): self.samples.setdefault(target, []).append(feature) if self.budget is not None: self.samples[target] = self.samples[target][-self.budget:] self.samples = {k: self.samples[k] for k in active_targets} def distance(self, features, targets): cost_matrix = np.zeros((len(targets), len(features))) for i, target in enumerate(targets): cost_matrix[i, :] = self._metric(self.samples[target], features) return cost_matrix