U
    md'2                    @   s<  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dl	m
Z
 d dl	mZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlm Z  d dlm!Z! d dlm"Z" d dlm#Z# d dlm$Z$ d dlm%Z% d dl&m'Z'm(Z( d dlm)Z)m*Z* d dlm+Z+ d dl,m-Z- d dl.m/Z/ d dl0m1Z1 d dl2m3Z3 e e"e%gZ4dd!d"Z5d#d$ Z6d%d& Z7d'd( Z8d)d* Z9ej:;d+d,d gd-d. Z<d/d0 Z=d1d2 Z>d3d4 Z?d5d6 Z@d7d8 ZAd9d: ZBd;d< ZCd=d> ZDd?d@ ZEdAdB ZFdCdD ZGej:;dEeHd dFd dGgd dFdGgfeHd dFd dGgdfdHdIdHdJgdHdIdJgfdHdIdHdJgdfgdKdL ZIej:;dEeHd dGd dGgd dFdGgfeHdHdMdHdMgdHdIdMgfgdNdO ZJej:;dEeHd dFdGdGgdfdHdIdJdJgdfd dFdGdGgd dFdGgfdHdIdJdJgdHdIdJgfgdPdQ ZKej:;dRdSdTdUgdVdW ZLdXdY ZMej:;dZd[eHd dFdGdGgd dGd gfd[eHdHdIdJdJgdHdHdIgfd\eHd dGd dGgdfd]eHdHdIdJdJgdHdJdIgfd^eHd dFdGdGgd dFgfd^eHdHdIdJdJgdHdIgfd_eHd dFdGdGgd dFdGd`gfd_eHdHdIdJdJgdHdIdJdMgfdaeHdHdIdJdbgdHdIdJgfdaeHdHdIdJdMgdHdIdJgfdaeHd dFdGd`gd dFdGgfgej:;dcdddegdfdg ZNej:;dhdidjdddkfdldjdedkfdmddg dnfdodddpdqfdrdcdsifdti fgdudv ZOdwdx ZPej:;dye4dzd{ ZQej:;dye4d|d} ZRej:;dye4d~d ZSdd ZTdd ZUdd ZVdd ZWdd ZXdd ZYej:;dd d dFgd dpdFgd gd gfd d dFgd ddpgd gd gfd d dFgdpddFgd gd gfd d dFgddpdgd gd gfd dFd gd dpdFgdpgd gfd dFd gd ddpgdpgd gfd dFd gdpddFgdpgd gfd dFd gddpdgdpgd gfd dFdFgd dpdFgdgd gfd dFdFgd ddpgd gd gfd dFdFgdpddFgd gd gfd dFdFgddpdgd gd gfdFd d gd dpdFgdFdFdpgd dFdFgfdFd d gd ddpgdFdFdpgd dFdFgfdFd d gdpddFgdFdFdpgd dFdFgfdFd d gddpdgdFdFdpgd dFdFgfdFd dFgd dpdFgdFdFd gd dpdpgfdFd dFgd ddpgdFdFd gd dpdpgfdFd dFgdpddFgdFdFd gd dpdpgfdFd dFgddpdgdFdFd gd dpdpgfgdd ZZej:;ddFd gdpdpgdFgd gfd dFgdpdpgdFgd gfd d dFgddpdpgdpgd gfd dFd gddpdpgdpgd gfd dFdFgddpdpgd gd gfdFd d gddpdpgdFgd gfdFd dFgddpdpgdFgd gfdFdFd gddpdpgdFgd gfgdd Z[dd Z\ej:;dd ddpddFgdd Z]ej:;dd d d d d dFgd d d d dFdFgd d d dFdFdFgd d dFdFdFdFgd dFdFdFdFdFggdd Z^ej:;dd dFgd dpdFgdfd dFdFgd dpgdfd d d gd dpdFgdfdFdFdFgd dpdFgdfdddgdddgdfgdd Z_dd Z`dd Zadd Zbdd Zcdd Zddd Zedd ZfdddZgej:;deaeeedebfej:;de!effdd Zhdd Ziej:;ddej:;dd¡ej:;dejdFddń ZkddǄ ZlddɄ Zmdd˄ Znej:;ddFd dFgddpdpgfdFd dFgddpdpggfdFd dFggddpdpgfgdd΄ ZoddЄ Zpdd҄ ZqddԄ Zrddք Zsdd؄ Ztddڄ Zudd܄ Zvddބ Zwdd Zxej:;dd,d gdd Zydd Zzdd Z{dd Z|ej:;dd dFdGd`gdFdfd dFdGd`gdGdpfd dFdGd`gd`dfgdd Z}ej:;deHdddFdFgdFdFfeHddFddFgdFdpfeHddFddFgdGdFfeHddddgdFdFfeHddddgdFdpfeHddddgdGdFfgdd Z~ej:;deHd dFdFdGgdd dFdGd`gfeHd dFdFdFgdpd dFdGd`gfeHdFdFdFdFgdpd dFdGd`gfeHdHdbdbdHgddHdIdMdbgfgej:;dd,d gdd Zdd Zej:;dd dFdGd`gdFdfd dFdGd`gdGdpfd dFdGd`gd`dFfgdd Zej:;dd dFdGd`gdfd dFdGd`gdfgdd Zej:;dd ddFdGgdd dgdddgdddgddpd ggddfd dFdGd`gdd dgdddgdddgddpd ggddfdJdJdHdIgdd dgdddgdddgddpd ggdHdIdJdJgdfdJdJdHdIgdd dgdddgdddgddpd ggdHdJdIgdfd d dFdGgdd dgdddgdddgddpd ggd dFdGd`gdfd d dFdGgdd dgdddgdddgddpd ggd dFd`gdfd dFgdpddgdddggddfgd	d
 Zdd ZdS (      N)
csr_matrix)stats)datasets)svm)softmax)make_multilabel_classification)_sparse_random_matrix)check_arraycheck_consistent_length)check_random_state)assert_allcloseassert_almost_equal)assert_array_equal)assert_array_almost_equal)accuracy_scoreauc)average_precision_score)coverage_error)	det_curve)%label_ranking_average_precision_score)precision_recall_curve)label_ranking_loss)roc_auc_score)	roc_curve)_ndcg_sample_scores_dcg_sample_scores)
ndcg_score	dcg_score)top_k_accuracy_score)UndefinedMetricWarning)train_test_split)LogisticRegression)label_binarizeFc                 C   s  | dkrt  } | j}| j}|r:||dk  ||dk   }}|j\}}t|}td}|| || ||  }}t	|d }tj
d}tj|||d| f }tjdddd}	|	|d| |d| ||d }
|r|
ddd	f }
|	||d }||d }|||
fS )
zMake some classification predictions on a toy dataset using a SVC

    If binary is True restrict to a binary classification problem instead of a
    multiclass classification problem
    N   %   r      ZlinearT)ZkernelZprobabilityrandom_state   )r   Z	load_irisdatatargetshapenparanger   shuffleintrandomRandomStateZc_Zrandnr   ZSVCfitpredict_probaZpredict)ZdatasetbinaryXy	n_samples
n_featuresprngZhalfclfy_scorey_predy_true r@   [/home/sam/Atlas/atlas_env/lib/python3.8/site-packages/sklearn/metrics/tests/test_ranking.pymake_prediction4   s*    


*rB   c                 C   sd   t | d }|| |k }|| |k }|dd|dd }t |dk}|tt|t|  S )zKAlternative implementation to check for correctness of
    `roc_auc_score`.r)   r   )r-   uniquereshapesumfloatlen)r?   r=   	pos_labelposnegZdiff_matrixZ	n_correctr@   r@   rA   _aucd   s    rL   c           	      C   s   t | d }t | |k}t |ddd }|| }| | } d}tt|D ]P}| | |krPd}td|d D ]}| | |krr|d7 }qr||d  }||7 }qP|| S )a>  Alternative implementation to check for correctness of
    `average_precision_score`.

    Note that this implementation fails on some edge cases.
    For example, for constant predictions e.g. [0.5, 0.5, 0.5],
    y_true = [1, 0, 0] returns an average precision of 0.33...
    but y_true = [0, 0, 1] returns 1.0.
    r)   NrC   r         ?)r-   rD   rF   ZargsortrangerH   )	r?   r=   rI   Zn_posorderscoreiprecjr@   r@   rA   _average_precisions   s    	

rT   c                 C   sd   t | |\}}}tt|}tt|}d}tdt|D ]$}||| || ||d    7 }q:|S )ao  A second alternative implementation of average precision that closely
    follows the Wikipedia article's definition (see References). This should
    give identical results as `average_precision_score` for all inputs.

    References
    ----------
    .. [1] `Wikipedia entry for the Average precision
       <https://en.wikipedia.org/wiki/Average_precision>`_
    r   r)   )r   listreversedrN   rH   )r?   r=   	precisionZrecall	thresholdZaverage_precisionrQ   r@   r@   rA   _average_precision_slow   s    
"rY   c                 C   s^   dd }|| ||\}}t ||}d}|}d||  ||  }	|| }
dd||	 |
|	    S )zcAlternative implementation to check for correctness of `roc_auc_score`
    with `max_fpr` set.
    c                 S   s   t | |\}}}|||k }t||}|||k }t||k}|d }	||	 || g}
||	 || g}t|t||
|}||fS )Nr)   )r   r-   appendargmaxZinterp)r?   	y_predictmax_fprfprtpr_new_fprnew_tprZidx_outZidx_inZx_interpZy_interpr@   r@   rA   _partial_roc   s    z,_partial_roc_auc_score.<locals>._partial_rocr         ?r)   r   )r?   r\   r]   rc   ra   rb   Zpartial_aucZfpr1Zfpr2Zmin_areaZmax_arear@   r@   rA   _partial_roc_auc_score   s    
re   dropTc           	      C   sz   t dd\}}}t||}t||| d\}}}t||}t||dd t|t|| |j|jksft|j|jksvtd S )NTr5   Zdrop_intermediater%   decimal)	rB   rL   r   r   r   r   r   r,   AssertionError)	rf   r?   r`   r=   Zexpected_aucr^   r_   
thresholdsroc_aucr@   r@   rA   test_roc_curve   s    

rn   c                  C   s   t jd} t dgd dgd  }| jddd}t||dd\}}}|d dksXt|d	 dksht|j|jksxt|j|jkstd S )
Nr   2   r)      d   sizeTrh   rC   )r-   r1   r2   arrayrandintr   rk   r,   )r;   r?   r>   r^   r_   Zthrr@   r@   rA   test_roc_curve_end_points   s    rv   c            
      C   s   t dd\} }}t| |\}}}g }|D ]2}t||k| @ }t| }	|d| |	  q(t||dd |j|jkszt|j|jkstd S )NTrg   rM   r%   ri   )rB   r   r-   rF   rZ   r   r,   rk   )
r?   r`   r=   r^   r_   rl   Ztpr_correctttpr:   r@   r@   rA   test_roc_returns_consistency   s    
ry   c               	   C   s4   t dd\} }}tt t| | W 5 Q R X d S )NFrg   )rB   pytestraises
ValueErrorr   )r?   r`   r=   r@   r@   rA   test_roc_curve_multi   s    r}   c                  C   s`   t dd\} }}t| |d \}}}t||}t|ddd |j|jksLt|j|jks\td S )NTrg   rd   ?r%   ri   )rB   r   r   r   r,   rk   )r?   r`   r=   r^   r_   rl   rm   r@   r@   rA   test_roc_curve_confidence   s    
r   c                  C   s  t dd\} }}t| j}t| |\}}}t||}t|ddd |j|jksTt|j|jksdtt| j}t| |\}}}t||}t|ddd |j|jkst|j|jkstt| |\}}}t||}t|ddd |j|jkst|j|jkstd S )NTrg   rd   r%   ri   g(\?)	rB   r-   onesr,   r   r   r   rk   zeros)r?   predr=   Ztrivial_predr^   r_   rl   rm   r@   r@   rA   test_roc_curve_hard  s$    


r   c               
   C   s  ddddddddddg
} ddddddddddg
}d}t jt|d t| |\}}}W 5 Q R X t|tt|tj |j	|j	kst
|j	|j	kst
d}t jt|d  tdd | D |\}}}W 5 Q R X t|tt|tj |j	|j	kst
|j	|j	kst
d S )Nr)   r   INo negative samples in y_true, false positive value should be meaninglessmatchHNo positive samples in y_true, true positive value should be meaninglessc                 S   s   g | ]}d | qS )r)   r@   ).0xr@   r@   rA   
<listcomp>1  s     z,test_roc_curve_one_label.<locals>.<listcomp>)rz   warnsr!   r   r   r-   fullrH   nanr,   rk   )r?   r>   expected_messager^   r_   rl   r@   r@   rA   test_roc_curve_one_label  s     $r   c               	   C   s  ddg} ddg}t | |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t | |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t | |\}}}t| |}t|ddg t|ddg t|d ddg} ddg}t | |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t | |\}}}t| |}t|ddg t|ddg t|d ddg} ddg}d}tjt|d	 t | |\}}}W 5 Q R X tt t| | W 5 Q R X t|dddg t|t	j
t	j
t	j
g ddg} ddg}d
}tjt|d	 t | |\}}}W 5 Q R X tt t| | W 5 Q R X t|t	j
t	j
t	j
g t|dddg t	ddgddgg} t	ddgddgg}tt t| |dd W 5 Q R X tt t| |dd W 5 Q R X tt| |ddd tt| |ddd t	ddgddgg} t	ddgddgg}tt t| |dd W 5 Q R X tt t| |dd W 5 Q R X tt| |ddd tt| |ddd t	ddgddgg} t	ddgddgg}tt| |ddd tt| |ddd tt| |ddd tt| |ddd t	ddgddgg} t	ddgddgg}tt| |ddd tt| |ddd tt| |ddd tt| |ddd d S )Nr   r)   rM           rd         ?      ?r   r   r   macroaverageweightedsamplesmicro)r   r   r   r   rz   r   r!   r{   r|   r-   r   rt   )r?   r=   r_   r^   r`   rm   r   r@   r@   rA   test_roc_curve_toydata8  s    









r   c                  C   s   ddddddg} ddddddg}t | |d	d
\}}}t|ddddg dddddddddddddg} dddddddddddddg}t | |d	d
\}}}t|ddddddg d S )Nr   r)   r   皙?rd   333333?ffffff?rM   Trh          @皙?皙?r~   )r   r   )r?   r=   r_   r^   rl   r@   r@   rA    test_roc_curve_drop_intermediate  s    r   c                  C   st   dddddg} dddddg}t dd	}t| ||d
\}}}t |dk  dksVtt |dk  dksptd S )Nr   r)   r   r   333333?皙?rd   r      sample_weight)r-   repeatr   diffrF   rk   )r?   r=   r   r^   r_   r`   r@   r@   rA   !test_roc_curve_fpr_tpr_increasing  s    r   c                  C   s   ddg} ddg}t t| |d ddg} ddg}t t| |d dddg} dddg}t t| |d ddg} ddg}t t| |d dddg} dddg}t t| |d d S Nr   r)   rd   )r   r   )r   r7   r@   r@   rA   test_auc  s    



r   c               	   C   s   t t tdddgddg W 5 Q R X t t tdgdg W 5 Q R X dddd	g} d
dddg}dt| }t jtt|d t| | W 5 Q R X d S )Nr   rd   rM   r   r   r%   r)   rp      r            z+x is neither increasing nor decreasing : {}r   )	rz   r{   r|   r   formatr-   rt   reescape)r   r7   error_messager@   r@   rA   test_auc_errors  s    r   zy_true, labelsr)   r%   abcc              	   C   s^  t dddgdddgdddgdd	dgg}td
dd
gdddg}tdd
dgdddg}|| d }td
d
dgdddg}tddd
gdddg}|| d }td
dgdd	g}	tdd
gddg}
|	|
 d }|| | d }tt| ||dd| |||g}dddg}t j||d}tt| ||ddd| d}tjt|d t| ||dd d W 5 Q R X d S )Nr   r   r   r   ffffff?rd   333333?r   r   r)   r%   rp   ovolabelsmulti_classr   )weightsr   r   r   r   z6average=None is not implemented for multi_class='ovo'.r   )r-   rt   r   r   r   rz   r{   NotImplementedError)r?   r   y_scoresscore_01score_10Zaverage_score_01Zscore_02Zscore_20Zaverage_score_02Zscore_12Zscore_21Zaverage_score_12Zovo_unweighted_scoreZpair_scoresZ
prevalenceZovo_weighted_scorer   r@   r@   rA   #test_multiclass_ovo_roc_auc_toydata  sB    "

    r   dc                 C   s   t dddgdddgdddgdddgg}tdd	dd	gddddg}td	dd	dgddddg}|| d
 }tt| ||dd| tt| ||ddd| d S )Nr   r   r   r   r   g?g?r)   r   r%   r   r   r   r   r-   rt   r   r   )r?   r   r   r   r   Z	ovo_scorer@   r@   rA   *test_multiclass_ovo_roc_auc_toydata_binary  s(    "     r   c                 C   s   t dddgdddgdddgdddgg}tdd	d	d	g|d d d	f }td	dd	d	g|d d df }td	d	ddg|d d d
f }tt| |d|d d|||g || | d }tt| |d|d| |d |d  |d  }tt| |d|dd| d S )NrM   r   r   rd   r   r   r   r)   r   r%   ovr)r   r   r   g      @)r   r   r   r   r   )r?   r   r   Zout_0Zout_1Zout_2Zresult_unweightedZresult_weightedr@   r@   rA   #test_multiclass_ovr_roc_auc_toydata=  s4    "    r   zmulti_class, average)r   r   )r   r   )r   r   c                 C   s   t ddddg}ddddgddddgddddgdddd	gg}tt||| |d
d ddddgddddgddddgddddgg}t||| |d
dk stdt d }t||| |d
tdkstd S )Nrp   r)   r%   r   r   rM   r   g?r   r   r   r   )r   r   rd   )r-   rt   r   r   rk   r   rz   approx)r   r   r?   Z	y_perfectZy_imperfectZy_chancer@   r@   rA   0test_perfect_imperfect_chance_multiclass_roc_auch  s8    	







   r   c           	         s   |  t jjdddgd d}t fdd|D }t|dd	d
gd}t| | \}}}t||}t	||ddd}|t
|kstd S )Nr   rM   rd     )rs   r(   c                    s"   g | ]}t jjd | d qS )r)   )nr:   r(   )r   Zmultinomialrvsr[   )r   Zy_pred_iseedr@   rA   r     s   z3test_micro_averaged_ovr_roc_auc.<locals>.<listcomp>r   r)   r%   )classesr   r   r   )r   Z	dirichletr   r-   asarrayr$   r   Zravelr   r   rz   r   rk   )	Zglobal_random_seedr>   r?   Zy_onehotr^   r_   r`   Zroc_auc_by_handZroc_auc_autor@   r   rA   test_micro_averaged_ovr_roc_auc  s    

r   zmsg, y_true, labelsz!Parameter 'labels' must be uniquezKNumber of classes in y_true not equal to the number of columns in 'y_score'z"Parameter 'labels' must be orderedzMNumber of given labels, 2, not equal to the number of columns in 'y_score', 3zMNumber of given labels, 4, not equal to the number of columns in 'y_score', 3rp   z2'y_true' contains labels not in parameter 'labels'er   r   r   c              	   C   sX   t dddgdddgdddgdd	dgg}tjt| d
 t||||d W 5 Q R X d S )Nr   r   r   r   r   rd   r   r   r   r   r   )r-   rt   rz   r{   r|   r   )msgr?   r   r   r   r@   r@   rA   *test_roc_auc_score_multiclass_labels_error  s
    ?"r   zmsg, kwargszLaverage must be one of \('macro', 'weighted', None\) for multiclass problemsr   )r   r   zUaverage must be one of \('micro', 'macro', 'weighted', None\) for multiclass problemszksample_weight is not supported for multiclass one-vs-one ROC AUC, 'sample_weight' must be None in this case)r   r   z|Partial AUC computation not available in multiclass setting, 'max_fpr' must be set to `None`, received `max_fpr=0.5` insteadrd   )r   r]   zbmulti_class='ovp' is not supported for multiclass ROC AUC, multi_class must be in \('ovo', 'ovr'\)Zovpz'multi_class must be in \('ovo', 'ovr'\)c              	   C   sX   t d}|dd}t|}|jdddd}tjt| d t||f| W 5 Q R X d S )N     rp   r   rr   r   )r   randr   ru   rz   r{   r|   r   )r   kwargsr;   r=   Zy_probr?   r@   r@   rA   #test_roc_auc_score_multiclass_error  s    .r   c               
   C   st  t d} | d}tjddd}d}tjt|d t|| W 5 Q R X tjddd}tjt|d t|| W 5 Q R X tj	dddd}tjt|d t|| W 5 Q R X t
jdd	 t d} | d}tjddd}tjt|d t|| W 5 Q R X tjddd}tjt|d t|| W 5 Q R X tj	dddd}tjt|d t|| W 5 Q R X W 5 Q R X d S )
Nr   
   r0   ZdtypezROC AUC score is not definedr   rC   T)record)r   r   r-   r   rz   r{   r|   r   r   r   warningscatch_warnings)r;   r>   r?   err_msgr@   r@   rA   test_auc_score_non_binary_class#  s0    

r   
curve_funcc              	   C   sN   t d}|jdddd}|d}d}tjt|d | || W 5 Q R X d S )Nr   r   rp   r   rr   z"multiclass format is not supportedr   )r   ru   r   rz   r{   r|   )r   r;   r?   r>   r   r@   r@   rA   &test_binary_clf_curve_multiclass_errorC  s    
r   c              	   C   s   d}t jt|d" | tjddgddddg W 5 Q R X t jt|d" | tjddgtdddg W 5 Q R X d	}t jt|d" | tjd
dgddddg W 5 Q R X ddddg}| ddddg|}| ddddg|}t||D ]\}}tj|| qd S )Nzy_true takes value in {'a', 'b'} and pos_label is not specified: either make y_true take value in {0, 1} or {-1, 1} or pass pos_label explicitly.r   r   r   z<U1r   r   rM   zy_true takes value in {b'a', b'b'} and pos_label is not specified: either make y_true take value in {0, 1} or {-1, 1} or pass pos_label explicitly.   a   bz<S1r   gzG?r   r)   )	rz   r{   r|   r-   rt   objectziptestingr   )r   r   r>   Z	int_curveZfloat_curveZint_curve_partZfloat_curve_partr@   r@   rA   (test_binary_clf_curve_implicit_pos_labelM  s    &&&r   c                 C   s   dddddg}dddddg}dddddg}| |||d}| |d d	 |d d	 |d d	 d}t ||D ]\}}t|| qhd S )
Nr   r)   r   r   r   r   rd   r   rC   )r   r   )r   r?   r=   r   Zresult_1Zresult_2Zarr_1Zarr_2r@   r@   rA   (test_binary_clf_curve_zero_sample_weightq  s    &r   c            	   	   C   s4  t dd\} }}t| | t| dd  |dd  \}}}|d dksJt|d | dd   ksftd| t| dk< |  }t| | t||  ddddg}dddd	g}t||\}}}t	|t
d
dd
ddg t	|t
dd
d
d
dg t	|t
dddd	g |j|jkst|j|jd ks0td S )NTrg   r)   r   rM   rC   r%   rp   r   rd   gQUU?r   )rB   _test_precision_recall_curver   rk   meanr-   wherecopyr   r   rt   rs   )	r?   r`   r=   r:   rrw   Zy_true_copyr   Zpredict_probasr@   r@   rA   test_precision_recall_curve~  s"    
 

r   c                 C   s   t | |\}}}t| |}t|dd t|t| | tt| ||dd |j|jksZt|j|jd ksntt | t	|\}}}|j|jkst|j|jd kstd S )Ngrh|?rp   r%   ri   r)   )
r   rY   r   r   r   rT   rs   rk   r-   
zeros_like)r?   r=   r:   r   rl   Zprecision_recall_aucr@   r@   rA   r     s"    
   r   c               
   C   s  t jdd ddg} ddg}t| |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t| |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t| |\}}}t| |}t|ddg t|ddg t|d ddg} ddg}t| |\}}}t| |}t|dddg t|dddg t|d ddg} ddg}t| |\}}}t| |}t|ddg t|ddg t|d ddg} dd	g}tjtd
d t| |\}}}W 5 Q R X tjtd
d t| |}W 5 Q R X t	|dddg t	|dddg t	|d ddg} dd	g}t| |\}}}tt| |d t|dddg t|dddg t 
ddgddgg} t 
ddgddgg}tjtd
d t	t| |ddd W 5 Q R X tjtd
d t	t| |ddd W 5 Q R X t	t| |ddd t	t| |ddd t 
ddgddgg} t 
ddgddgg}tjtd
d t	t| |ddd W 5 Q R X tjtd
d t	t| |ddd W 5 Q R X t	t| |ddd	 t	t| |ddd t 
ddgddgg} t 
ddgddgg}tt| |ddd tt| |ddd tt| |ddd tt| |ddd t 
ddgddgg} t 
ddgddgg}tjtd
d t	t| |ddd W 5 Q R X t	t| |ddd tjtd
d t	t| |ddd W 5 Q R X tjtd
d t	t| |ddd W 5 Q R X t 
ddgddgg} t 
ddgddgg}t	t| |ddd t	t| |ddd t	t| |ddd t	t| |ddd t 
ddgddgg} t 
ddgddgg}tt| |ddd tt| |ddd tt| |ddd tt| |ddd W 5 Q R X t jdd` t 
ddgddgg} t 
ddgddgg}tjtd
d t	t| |ddd W 5 Q R X W 5 Q R X d S )Nraise)allr   r)   rd   rM   r   r   r   z!No positive class found in y_truer   r   r   r   r   r   ignore)r-   Zerrstater   r   r   r   rz   r   UserWarningr   rt   )r?   r=   r:   r   r`   Zauc_prcr@   r@   rA   #test_precision_recall_curve_toydata  sD   










                     r   c                  C   s<   t jdtd} d| d d d< t d}t| |dks8td S )Nrq   r   r)   r   r   )r-   r   r0   r   r   rk   r?   r=   r@   r@   rA   &test_average_precision_constant_valuesO  s    
r   c               	   C   s   t ddg} t ddg}d}tjt|d t| |dd W 5 Q R X t ddgddgddgddgg} t ddgddgd	d
gd
d	gg}d}tjt|d t| |dd W 5 Q R X d S )Nr   r)   z>pos_label=2 is not a valid label. It should be one of \[0, 1\]r   r%   rI   r~   r   r   r   znParameter pos_label is fixed to 1 for multilabel-indicator y_true. Do not set pos_label or set pos_label to 1.)r-   rt   rz   r{   r|   r   r?   r>   r   r@   r@   rA   -test_average_precision_score_pos_label_errors]  s    ""r   c                  C   s   t dd\} }}t| |}t| d| }t| d| }t| |d }||ksPt||ks\t||kshtt| |}t| d| }t| d| }	t| |d }
||kst||	kst||
kstd S )NTrg   rq   gư>r   )rB   r   rk   r   )r?   r`   r=   rm   Zroc_auc_scaled_upZroc_auc_scaled_downZroc_auc_shiftedZpr_aucZpr_auc_scaled_upZpr_auc_scaled_downZpr_auc_shiftedr@   r@   rA   test_score_scale_invariancep  s    

r   z(y_true,y_score,expected_fpr,expected_fnrr   r   r   c                 C   s(   t | |\}}}t|| t|| d S Nr   r   r?   r=   Zexpected_fprZexpected_fnrr^   fnrr`   r@   r@   rA   test_det_curve_toydata  s    
r  c                 C   s(   t | |\}}}t|| t|| d S r   r   r   r@   r@   rA   test_det_curve_tie_handling  s    
r  c                
   C   s>   t tdddgdddgtddddddgddddddg d S r   )r   r   r@   r@   r@   rA   test_det_curve_sanity_check  s     r  r=   c                 C   sN   t ddddddgtd| d\}}}t|dg t|dg t|| g d S )Nr   r)   r   r   )r   r-   r   r   )r=   r^   r  rX   r@   r@   rA   test_det_curve_constant_scores  s     
r  r?   c                 C   s.   t | | d\}}}t|dg t|dg d S )Nr   r   r   )r?   r^   r  r`   r@   r@   rA   test_det_curve_perfect_scores  s    r  zy_true, y_pred, err_msgzinconsistent numbers of samplesz Only one class present in y_truecancer
not cancerr   r   r   zpos_label is not specifiedc              	   C   s(   t jt|d t| | W 5 Q R X d S )Nr   )rz   r{   r|   r   r   r@   r@   rA   test_det_curve_bad_input  s    r	  c            	      C   s   dgd dgd  } t ddddd	dd
dddg
}d| }t| |dd\}}}t| |dd\}}}|d tdksxt|d tdkstt||d d d  t||d d d  d S )Nr  rp   r  r   r   r   r   r   r   rd   r   r~   r)   r   r   rC   )r-   rt   r   rz   r   rk   r   )	r?   Zy_pred_pos_not_cancerZy_pred_pos_cancerZfpr_pos_cancerZfnr_pos_cancerZth_pos_cancerZfpr_pos_not_cancerZfnr_pos_not_cancerZth_pos_not_cancerr@   r@   rA   test_det_curve_pos_label  s"    r
  c                 C   sP  t | ddggddggd t | ddggddggd t | ddggddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | ddggddggd t | ddggddggd t | ddggddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd t | dddggdddggd	 t | ddddggddddggd d S )
Nr   r)   r   r   rd   gUUUUUU?g?g?UUUUUU?r   
lrap_scorer@   r@   rA   check_lrap_toy  sj                                r  c                 C   s   t d}tddD ]}|jd|fd}t|}td|f}| ||dksPt| ||dksbttd|f}| ||dkst| ||dkstqt| dgdgdgdggdgdgdgdggd d S )Nr   r%   r   r)   rr   rM   rd   )	r   rN   uniformr-   r   r   rk   r   r   )r  r(   n_labelsr=   Zy_score_tiesr?   r@   r@   rA   !check_zero_or_all_relevant_labelsK  s    
( r  c              	   C   s  t t | dddgdddg W 5 Q R X t t. | dddgdddgd	ddgd
ddgg W 5 Q R X t t. | dddgdddgd	ddgd
ddgg W 5 Q R X t t  | ddgddggddg W 5 Q R X t t" | ddgddggddgg W 5 Q R X t t$ | ddgddggdgdgg W 5 Q R X t t" | ddggddgddgg W 5 Q R X t t$ | dgdggddgddgg W 5 Q R X t t$ | ddgddggdgdgg W 5 Q R X d S )Nr   r)   r   r   r   r%   r   r   r   r   )rz   r{   r|   r  r@   r@   rA   check_lrap_error_raisedb  s*     2 $&(&(r  c              	   C   sz   t ddD ]j}td|f}t d|D ]L}t || D ]:}td|f}d|d||| f< t| ||||  q6q&q
d S )Nr%   r   r)   r   )rN   r-   r   r   r   )r  r  r=   
n_relevantrJ   r?   r@   r@   rA   check_lrap_only_ties|  s    r  c              	      s   t ddD ]}|t|d|fd  }td|f}d|d< d|d< t| ||d| d d  t d|D ]` t |  D ]Ntd|f}d|d  f< t| ||t fddt  D  q~qnq
d S )	Nr%   r   r)   )r   r   )r   rC   r   c                 3   s&   | ]}|d  | d     V  qdS )r)   Nr@   r   r   r  rJ   r@   rA   	<genexpr>  s   z>check_lrap_without_tie_and_increasing_score.<locals>.<genexpr>)rN   r-   r.   rE   r   r   rF   )r  r  r=   r?   r@   r  rA   +check_lrap_without_tie_and_increasing_score  s     r  c                    s
  t | | t| } t|}| j\}}t|f}t|D ]}tj|| dd\}}|j}|| tj|d d	 }	|	 | | 
 d }
|
jdks|
j|krd||< q8d||< |
D ]4 t fdd|
D }||  |   7  < q||  |
j  < q8| S )	z8Simple implementation of label ranking average precisionT)Zreturn_inverser)   )Z	minlengthr   r   c                 3   s   | ]}|   kV  qd S r   r@   r  labelZrankr@   rA   r    s     z_my_lrap.<locals>.<genexpr>)r
   r	   r,   r-   emptyrN   rD   rs   ZbincountZcumsumZnonzerorF   r   )r?   r=   r8   r  rP   rQ   Zunique_rankZinv_rankZn_ranksZ	corr_rankZrelevantZn_ranked_abover@   r  rA   _my_lrap  s*    

r  r   r   c           	      C   s   t dd|||d\}}t|jd |jd |d}t|drB| }t||}t||}t|| t|}|j	||fd}t||}t||}t|| d S )Nr)   F)r9   Zallow_unlabeledr(   	n_classesr8   r   )Zn_componentsr9   r(   toarrayrr   )
r   r   r,   hasattrr  r   r  r   r   r  )	r  r  r8   r(   r`   r?   r=   Z
score_lrapZscore_my_lrapr@   r@   rA   %check_alternative_lrap_implementation  s,    
	





r   checkfuncc                 C   s   | | d S r   r@   )r!  r"  r@   r@   rA   test_label_ranking_avp  s    r#  c                   C   s   t t d S r   )r  r   r@   r@   r@   rA   test_lrap_error_raised  s    r$  r8   )r)   r%   r   r   r  )r%   r   r   r(   c                 C   s   t t|| | d S r   )r   r   )r8   r  r(   r@   r@   rA   $test_alternative_lrap_implementation  s       r%  c                  C   s   t jddddgddddgddddggtd} t ddddgddddgddddgg}t dd	d
g}t d
d
dg}tt| ||dt || t |  d S )Nr)   r   r   r   r   r   r   rd   r   rM   r   r   )r-   rt   boolr   r   rF   )r?   r=   Zsamplewise_lrapsr   r@   r@   rA   &test_lrap_sample_weighting_zero_labels  s    ,   r'  c                   C   s$  t tddggddggd t tddggddggd t tddggddggd t tddggddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddgdddggdd	d
gdddggd t tdddgdddgdddggddd
gdddgdddggd t tdddgdddgdddggddd
gdddgdddggd d S )Nr   r)   r   r   r%   rd   rp   r         $@r   r   g@r   r   r@   r@   r@   rA   test_coverage_error  s\                            (  r+  c                   C   st  t tddggddggd t tddggddggd t tddggddggd t tddggddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd d S )Nr   rd   r)   r%   r   rp   r*  r@   r@   r@   rA   test_coverage_tie_handlingT  s           r,  zy_true, y_scorec              	   C   s(   t jtdd t| | W 5 Q R X d S )Nz'Expected 2D array, got 1D array insteadr   )rz   r{   r|   r   r   r@   r@   rA   test_coverage_1d_error_messaged  s    r-  c                	   C   s  t tddggddggd t tddggddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tddggddggd t tddggddggd t tddggddggd t tddggddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddgdddggddd	gddd
ggd t tdddgdddgdddggddd	gddd
gdddggd t tdddgdddgdddggddd	gd
dd
gdddggd t tttdddgdddggddd	gd
dd
ggd d S )Nr   r)   r   r   rd   rM   r   r(  r)  rp   r   r%   )r   r   r   r-   rt   r@   r@   r@   rA   test_label_ranking_losss  sR              (   r.  c                	   C   s4  t t  tddgddggddg W 5 Q R X t t" tddgddggddgg W 5 Q R X t t$ tddgddggdgdgg W 5 Q R X t t" tddggddgddgg W 5 Q R X t t$ tdgdggddgddgg W 5 Q R X t t$ tddgddggdgdgg W 5 Q R X d S )Nr   r)   )rz   r{   r|   r   r@   r@   r@   rA   $test_ranking_appropriate_input_shape  s    $&(&(r/  c                   C   s   t tddggddggd t tddggddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd t tdddggdddggd d S )Nr)   r   rd   r   )r   r   r@   r@   r@   rA   test_ranking_loss_ties_handling  s         r0  c                  C   sH   t ddd\} }| d }t|| tjdd\}}t|| d S Nr   r   )r(   r  r)   )r%   rq   r   )r   _test_dcg_score_forr-   r1   r2   random_sampler`   r?   r=   r@   r@   rA   test_dcg_score  s
    

r5  c                 C   s   t t | jd d }t| | }t| |}||k s>tt| | dd|k sXt|j| jd fksnt|j| jd fkst|tt 	| d d d d df | j
ddkstd S )Nr)   r%   r   kr   rC   Zaxis)r-   log2r.   r,   r   r   rk   rz   r   sortrF   )r?   r=   ZdiscountidealrP   r@   r@   rA   r2    s    

r2  c               	   C   sN  t t dg} t | j}t| |}t| |dd}dt t dd }|t|	 | 
  gkslt|t|| d d d d df  	 gkstd|dd	d f< t| |}t| |dd}|t|| d d d d df  	 gkst|t|d d 	 | dd	d f 
  |dd  	 | dd d	f 
   gksJtd S )
Nr   Tignore_tiesr)   r%   r   rC   r   rp   )r-   r   r.   r   r,   r   r9  rz   r   rF   r   rk   )r?   r=   ZdcgZdcg_ignore_tiesZ	discountsr@   r@   rA   test_dcg_ties  s"    
 .
.""r>  c               	   C   s>   t dd} t| | dddtt| | dddks:td S )N   )r%   r   rp   T)r7  r=  )r-   r.   rE   r   rz   r   rk   )r   r@   r@   rA   test_ndcg_ignore_ties_with_k  s    r@  c               	   C   sf   t dddddgg} t dddd	d
gg}d}tjt|d t| |tdksXtW 5 Q R X d S )Ng{Gzg(\gGz޿g(\?gQ?gQ?gףp=
?r   gQ?gHzG?zndcg_score should not be used on negative y_true values. ndcg_score will raise a ValueError on negative y_true values starting from version 1.4.r   g&x@)r-   rt   rz   r   FutureWarningr   r   rk   )r?   r=   r   r@   r@   rA   test_ndcg_negative_ndarray_warn  s    rB  c                  C   s   t ddd} | t jdjdd| jd }t| |}t| |dd	}|t	|ksZt
|t	d
kslt
|d7 }t| |t	d
kst
d S )NF   r   r   r   皙ɿr   rr   Tr<  rM   r   )r-   r.   rE   r1   r2   r  r,   r   rz   r   rk   )r?   r=   ZndcgZndcg_no_tiesr@   r@   rA   test_ndcg_invariant  s    
rE  r=  c              
   C   s"  dt dd d  }t t dddd}|t jdjdd	|jd
 }t||| dt	
dt t dd ksxtt||| dt	
dt t dd kstt||| dt	
dt t dd kstt||d| dt	
dt t dd kstt||| dt	
dt t dd  ks:tt||| dt	
dt t dd  ksntdt d }dt t dd  }t||| dt	
|t d kstt||| dt	
t dkstt||| dt	
|ks tt||| dt	
dkstd S )Nrp   r   r   r   rC   )r   r)   r   rD  r   rr   r<  r%   r)   r   )Zlog_baser=  )r   r   	   rM   )r-   eyeZtiler.   r1   r2   r  r,   r   rz   r   r9  rk   r   log10r   r   r   r   rF   )r=  r?   r=   Zy_score_noisyZexpected_dcg_scorer@   r@   rA   test_ndcg_toy_examples  s|               
  
  
rI  c                  C   sH   t ddd\} }| d }t|| tjdd\}}t|| d S r1  )r   _test_ndcg_score_forr-   r1   r2   r3  r4  r@   r@   rA   test_ndcg_score.  s
    

rK  c                 C   s   t | | }t | |}||k s$t| dkjdd}||  tt|  ksXt|| tt| ksxt||  tt	| ||  t	| | |   kst|| tt| kst|j
| j
d fkst|j
| j
d fkstd S )Nr   r)   r8  )r   r   rk   rz   r   r-   r   rF   r   r   r,   )r?   r=   r;  rP   Zall_zeror@   r@   rA   rJ  6  s    

$ 
 rJ  c               	   C   sJ  t ddddg} t| | dddks(tt| | dddks>ttt t| | dds\tW 5 Q R X tt t| | ddstW 5 Q R X tt t| | ddstW 5 Q R X t ddddg}t| |dd}t| |}||kstt| |d	dd
ksttdd\} }}t dddD ]"}t	t| ||dt
| || q"d S )Nr   r)   )r]   gMbP?gg?r   g{Gz?r   rd   Trg   g-C6?r   )r-   rt   r   rk   rz   r{   r|   rB   Zlinspacer   re   )r?   r   Zroc_auc_with_max_fpr_oneZunconstrained_roc_aucr>   r`   r]   r@   r@   rA   test_partial_roc_auc_scoreF  s(    

rL  zy_true, k, true_scorec              	   C   sV   t ddddgddddgddddgddddgg}t| ||d}|t|ksRtd S )Nr   r   r   r   r6  r-   rt   r    rz   r   rk   )r?   r7  
true_scorer=   rP   r@   r@   rA   test_top_k_accuracy_score`  s    	



rO  zy_score, k, true_scorerC   r   c                 C   s   ddddg}|   dkr(|  dkr(dnd}|dkrD| |ktjn|}t|| |d}t||}||  krzt|ksn t	d S )Nr   r)   rd   r6  )
minmaxZastyper-   Zint64r    r   rz   r   rk   )r=   r7  rN  r?   rX   r>   rP   Z	score_accr@   r@   rA    test_top_k_accuracy_score_binaryu  s     
rR  zy_true, true_score, labelslabels_as_ndarrayc              	   C   sf   |rt |}t ddddgddddgddddgddddgg}t| |d|d}|t|ksbtdS )z,Test when labels and y_score are multiclass.r   r   r   r   r%   r7  r   N)r-   r   rt   r    rz   r   rk   )r?   rN  r   rS  r=   rP   r@   r@   rA   0test_top_k_accuracy_score_multiclass_with_labels  s    




	rU  c                     s   t jddddd\ t dd\} }}}tdd| | t| |f||fD ]<\  fddtddD }tt	|dksTt
qTd S )	Nr   r   r   )r  r8   Zn_informativer(   )r(   c                    s    g | ]}t  |d qS )r6  )r    r4   )r   r7  r6   r<   r7   r@   rA   r     s    z8test_top_k_accuracy_score_increasing.<locals>.<listcomp>r%   )r   Zmake_classificationr"   r#   r3   r   rN   r-   r   r   rk   )ZX_trainZX_testZy_trainZy_testZscoresr@   rV  rA   $test_top_k_accuracy_score_increasing  s       

rW  c              	   C   sR   t ddddgddddgddddgddddgg}t| ||dt|ksNtd S )Nr   r   r   r)   rp   r6  rM  )r?   r7  rN  r=   r@   r@   rA   test_top_k_accuracy_score_ties  s    




rX  z	y_true, kr   c              	   C   sn   t ddddgddddgddddgddddgg}d}tjt|d t| ||d}W 5 Q R X |dksjtd S )	Nr   r   r   r   zu'k' \(\d+\) greater than or equal to 'n_classes' \(\d+\) will result in a perfect score and is therefore meaningless.r   r6  r)   )r-   rt   rz   r   r!   r    rk   )r?   r7  r=   r   rP   r@   r@   rA   !test_top_k_accuracy_score_warning  s    



	rY  zy_true, y_score, labels, msgg=
ףp=?r   r   z9y type must be 'binary' or 'multiclass', got 'continuous'zZNumber of classes in 'y_true' \(4\) not equal to the number of classes in 'y_score' \(3\).z"Parameter 'labels' must be unique.z#Parameter 'labels' must be ordered.zSNumber of given labels \(4\) not equal to the number of classes in 'y_score' \(3\).z3'y_true' contains labels not in parameter 'labels'.z}`y_true` is binary while y_score is 2d with 3 classes. If `y_true` does not contain all the labels, `labels` must be providedc              	   C   s.   t jt|d t| |d|d W 5 Q R X d S )Nr   r%   rT  )rz   r{   r|   r    )r?   r=   r   r   r@   r@   rA   test_top_k_accuracy_score_error  s    QrZ  c                  C   sR   t dddgdddgg} tdddgdddgg}t| |}|tdksNtd S )Nr)   r   rd   r~   r   r  )r   r-   rt   r   rz   r   rk   )r?   r=   resultr@   r@   rA   Otest_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_input?  s    
r\  )NF)r   r   r   )r   rz   numpyr-   r   Zscipy.sparser   Zscipyr   Zsklearnr   r   Zsklearn.utils.extmathr   Zsklearn.datasetsr   Zsklearn.random_projectionr   Zsklearn.utils.validationr	   r
   r   Zsklearn.utils._testingr   r   r   r   Zsklearn.metricsr   r   r   r   r   r   r   r   r   r   Zsklearn.metrics._rankingr   r   r   r   r    Zsklearn.exceptionsr!   Zsklearn.model_selectionr"   Zsklearn.linear_modelr#   Zsklearn.preprocessingr$   ZCURVE_FUNCSrB   rL   rT   rY   re   markZparametrizern   rv   ry   r}   r   r   r   r   r   r   r   r   rt   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r	  r
  r  r  r  r  r  r  r   r#  r$  rN   r%  r'  r+  r,  r-  r.  r/  r0  r5  r2  r>  r@  rB  rE  rI  rK  rJ  rL  rO  rR  rU  rW  rX  rY  rZ  r\  r@   r@   r@   rA   <module>   s  0
h	
2
	
"
"

=				
*
 
	
#
 $""""""""






;'    
 	7
3
&

	



























P