U
    md>                     @   s	  d Z ddlZddlZddlZddlZddlmZ ddl	Z	ddl
mZ ddl
mZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ ddlm Z  ddlm!Z! ddl"m#Z# ddl"m$Z$ ddl"m%Z% ddl"m&Z& ddl"m'Z' ddl"m(Z( ddl"m)Z) ddl
m*Z* ddl+m,Z, ddl-m.Z. dZ/e	j01de/ Z2e3dddddgddd ddgdddddggZ4d!Z5e4j6\Z7Z8e,e5e4dd"d#\Z9Z:e;e9Z<e	j0j=d$ej3ej;gd%d&gd'e	j0=d(d)d*ge	j0=d+ej>ej?gd,d- Z@e	j0j=d$ej3ej;gd%d&gd'e	j0=d(d)d*gd.d/ ZAe	j0j=d$ej3ej;gd%d&gd'd0d1 ZBe	j0=d2d3d4ge	j0j=d$ej3ej;gd%d&gd'e	j0=d5d6d7d8dgd9d: ZCe	j0=d;d)d*gd<d= ZDe	j0=d;d>d?gd@dA ZEdBdC ZFdDdE ZGe	j0j=dFe9e<gd%d&gd'e	j0j=dGdHdIe4dJdK gdHdIdLdMgd'e	j0=dNeegdOdP ZHe	j0j=dGdHdIe4dQdK gdHdIdLdMgd'dRdS ZIe	j0=dNeegdTdU ZJdVdW ZKe	j0=d;d)d*ge	j0=d5d6dgdXdY ZLdZd[ ZMe	j0=dNeegd\d] ZNd^d_ ZOe	j0j=dFe9e<gd%d&gd'd`da ZPdbdc ZQddde ZRe	j0=dfdgdhgdidj ZSdkdl ZTdmdn ZUe	j0=dNeegdodp ZVe	j0j=d$ej3ej;gd%d&gd'e	j0=dqed)fed*fedfge	j0=drdsd!gdtdu ZWe	j0=dNeegdvdw ZXe	j0j=dGdHdIe4gdHdIdLgd'e	j0=dNeegdxdy ZYe	j0j=d$ej3ej;gd%d&gd'e	j0=d+ejZej[ge	j0=dGdIdLge	j0=dNeegdzd{ Z\e	j0=dNeegd|d} Z]e	j0=dNeegd~d Z^dd Z_dd Z`e	j0j=dFe9e<gd%d&gd'e	j0=dNeegdd Zae	j0=d+ejZej[ej>ej?ge	j0=dNeegdd Zbe	j0j=dFe9e<gd%d&gd'dd Zcdd Zddd Zedd Zfe	j0j=dFe9e<gd%d&gd'e	j0=dNeegdd Zge	j0j=dFe9e<gd%d&gd'e	j0=dNeegdd Zhdd Zie	j0j=d$ej3ej;gd%d&gd'dd Zje	j0=dNeegdd Zkdd Zle	j0j=d$ej3ej;gd%d&gd'e	j0=d(d)d*gdd Zme	j0=d+ej>ej?ge	j0=dddgdd Zne	j0=d+ej>ej?gdd Zoe	j0=dedfedfgdd Zpe	j0=dedfedfgdd Zqe	j0=dNeegdd Zre	j0=dNeege	j0=dde5d idfdGe9dds idfdGddK idfdGe9ddddsf idfdGddK idfgdd Zse	j0=dde9dds idfgdd Zte	j0=dFe9e<ge	j0=d+ej?ej>gdd Zue	j0=dee9dddgdd ZvddÄ Zwddń Zxe	j0=ddGeydIiddgddggddǜfddɄ Zze	j0=dedfedfedfgdd΄ Z{e	j0=dddgddф Z|dS )zTesting for K-means    N)sparse)assert_array_equal)assert_allclose)threadpool_limits)clone)ConvergenceWarning)	row_norms)pairwise_distances)pairwise_distances_argmin)v_measure_score)KMeansk_meanskmeans_plusplus)MiniBatchKMeans)_labels_inertia)_mini_batch_step)_relocate_empty_clusters_dense)_relocate_empty_clusters_sparse)_euclidean_dense_dense_wrapper)_euclidean_sparse_dense_wrapper)_inertia_dense)_inertia_sparse)_is_same_clustering)create_memmap_backed_data)
make_blobs)StringIOzThe default value of `n_init` will change from \d* to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning:FutureWarningzignore:        g      @      ?g      @d   *   )	n_samplescentersZcluster_stdrandom_statearray_constrZdenser   )Zidsalgolloydelkandtypec                 C   s   | ddgddgddgddgg|d}ddddg}t jddgddgg|d}ddddg}d}t jddgddgg|d}d	}	td	d||d
}
|
j||d t|
j| t|
j| t|
j| |
j	|	kst
d S )Nr         ?   r'      g      ?g      ?g      ?   
n_clustersn_initinit	algorithmsample_weight)nparrayr   fitr   labels_r   inertia_cluster_centers_n_iter_AssertionError)r#   r$   r'   Xr3   init_centersexpected_labelsexpected_inertiaexpected_centersexpected_n_iterkmeans rC   [/home/sam/Atlas/atlas_env/lib/python3.8/site-packages/sklearn/cluster/tests/test_k_means.pytest_kmeans_results;   s    $rE   c           	      C   s   | ddgddgddgddgg}t ddgddgg}tdd||d}|| d}d}t|j| |j|ksrtz8ddddg}ddgddgg}t|j	| t|j
| W nH tk
r   ddddg}dd	gdd
gg}t|j	| t|j
| Y nX d S )Nr   r(   r)   r+   r,   r-   g      ?g      ?r   r   )r4   r5   r   r6   r   r8   r:   r;   r   r7   r9   )	r#   r$   r<   r=   rB   rA   r?   r>   r@   rC   rC   rD   test_kmeans_relocated_clustersU   s$     
rF   c                 C   s   t ddddddddd	d
g
dd}| |}t d
}t dddgdd}t dddgdd}t dddg}t jd
t jd}| t jkrt|||||| nt|j|j	|j
||||| t|dddg t|dgd
gd	gg d S )Ng      $g      #ig      !ir)   	   g      #@
   ig     0g      $@r   r*      i)r4   r5   reshapeoneszerosint32r   r   dataindicesZindptrr   r   )r#   r<   r3   centers_oldcenters_newZweight_in_clusterslabelsrC   rC   rD   test_relocate_empty_clustersv   s8    &

     rT   distributionnormalZblobstolg{Gz?g:0yE>g0.++c           	      C   s   t j|}| dkr"|jdd}nt|d\}}d||dk < ||}td|d|d}td	d|d|d
}|| || t|j|j t	|j
|j
 |j|jkst|jtj|jddkstd S )NrV   i  rI   sizer"   r      r)   )r.   r"   r/   rW   r&   )r1   r.   r"   r/   rW   ư>)rel)r4   randomRandomStaterV   r   r   r6   r   r9   r   r7   r:   r;   r8   pytestZapprox)	rU   r#   rW   global_random_seedrndr<   _Zkm_lloydZkm_elkanrC   rC   rD   test_kmeans_elkan_results   s(    

re   r1   c                 C   sH   t j|}|jdd}d}t| d|dd|d|}|j|k sDtd S )NrX   rY   i,  r\   r)   r   )r1   r.   r"   r/   rW   max_iter)r4   r_   r`   rV   r   r6   r:   r;   )r1   rb   rc   r<   rf   kmrC   rC   rD   test_kmeans_convergence   s    	rh   autofullc              	   C   sV   t jdd}t| d}tjtd|  dd || |jdksHt	W 5 Q R X d S )Nr   r,   r1   zalgorithm='zB' is deprecated, it will be removed in 1.3. Using 'lloyd' instead.matchr%   )
r4   r_   Zrandr   ra   warnsFutureWarningr6   
_algorithmr;   )r1   r<   rB   rC   rC   rD   ,test_algorithm_auto_full_deprecation_warning   s    


rq   c              	   C   sv  t j| }t|jtjd }| }t |}t |}t j|jd t	j
d}t j|jd t	j
d}t jt	jd t	j
d}t	d d }	td d }
|d d }t|	||||t j| dd}|dkstt|	||\}}|dkst||k stt|
||||t j| dd}|dkstt|
||\}}|dks<t||k sJtt|| t|| t|| t|| d S )NrY   r   r*   rI   F)random_reassignr   )r4   r_   r`   r!   rV   shapecopyZ
zeros_likerM   r<   r'   rL   X_csrr   r;   r   r   r   )rb   rngrQ   Zcenters_old_csrrR   Zcenters_new_csrZweight_sumsZweight_sums_csrr3   ZX_mbZX_mb_csrZsample_weight_mbZold_inertiarS   Znew_inertiaZold_inertia_csrZ
labels_csrZnew_inertia_csrrC   rC   rD   !test_minibatch_update_consistency   sZ    


	
	  


rw   c                 C   sX   | j }|jttfkst| j}t|jd tks6ttt	t
|d | jdksTtd S )Nr   r   r   )r9   rs   r.   
n_featuresr;   r7   r4   uniquer   r   true_labelsr8   )rg   r!   rS   rC   rC   rD   _check_fitted_model#  s    r{   rO   r0   r_   	k-means++c                 C   s   t S Nr!   r<   kr"   rC   rC   rD   <lambda>4      r   ndarraycallable	Estimatorc                 C   s4   t |trdnd}| |td|d|}t| d S )NrI   r)   r   r0   r.   r"   r/   )
isinstancestrr.   r6   r{   )r   rO   r0   r/   rg   rC   rC   rD   test_all_init1  s    	   r   c                 C   s   t S r}   r~   r   rC   rC   rD   r   C  r   c                 C   sF   t | trdnd}t| td|d}tdD ]}|t q*t| d S )NrI   r)   r   r   r   )r   r   r   r.   rangepartial_fitr<   r{   )r0   r/   rg   irC   rC   rD   &test_minibatch_kmeans_partial_fit_initA  s       r   c                 C   s`   t t}t t}| ttd|dt}| t|d|d|}t|j|j t|j	|j	 d S )Nr)   r.   r0   r/   r"   )
r4   asfortranarrayr<   r!   r.   r6   r   r9   r   r7   )r   rb   	X_fortrancenters_fortranZkm_cZkm_frC   rC   rD   test_fortran_aligned_dataR  s(    

   r   c                  C   s8   t tddd} tj}t t_z| t W 5 |t_X d S )Nr   r)   )r.   r"   verbose)r   r.   sysstdoutr   r6   r<   )rg   Z
old_stdoutrC   rC   rD   test_minibatch_kmeans_verbosee  s    r   c              	   C   s   t jdjdd}t| tddd|dd| | }t	d|j
sJtt	d	|j
s\t|dkrxt	d
|j
stnt	d|j
std S )Nr   rX   rY   r   r_   r)   )r1   r.   r"   r0   r/   rW   r   zInitialization completezIteration [0-9]+, inertiazstrict convergencez center shift .* within tolerance)r4   r_   r`   rV   r   r.   r6   
readouterrresearchoutr;   )r1   rW   capsysr<   capturedrC   rC   rD   test_kmeans_verbosep  s$    
r   c                	   C   s0   t jtdd tdddt W 5 Q R X d S )Nz,init_size.* should be larger than n_clustersrl   rI      )	init_sizer.   )ra   rn   RuntimeWarningr   r6   r<   rC   rC   rC   rD   'test_minibatch_kmeans_warning_init_size  s
     r   c              	   C   s2   t jtdd | ttddt W 5 Q R X d S )NzAExplicit initial center position passed: performing only one initrl   rI   r0   r.   r/   )ra   rn   r   r!   r.   r6   r<   )r   rC   rC   rD   'test_warning_n_init_precomputed_centers  s
    r   c                 C   s   t dd| d\}}d|d d dd d f< tdd| dd	|}|jjd
d dksXttdd| dd	|}|jjd
d dksttd| dd}tdD ]}|| q|jjd
d dkstd S )Nr   r\   )r    r!   r"   r   r,   r   rI   r_   )r.   
batch_sizer"   r0   r)   Zaxis   )r.   r"   r0   )	r   r   r6   r9   anysumr;   r   r   )rb   Zzeroed_Xrz   rg   r   rC   rC   rD    test_minibatch_sensible_reassign  s8      
      r   c              
   C   s   t ttf}ttD ]}tt|k jdd||< qt t	}t 
|}t| ||dd  }t| |||t tt j|ddd t| ||dd  }||kstt| |||t tt j|ddd t|| d S )Nr   r   r)   T)rr   Zreassignment_ratiogV瞯<)r4   emptyr.   rx   r   r<   rz   meanrL   r    Z
empty_liker   r   rM   r_   r`   r;   r   )rO   rb   Zperfect_centersr   r3   rR   Zscore_beforeZscore_afterrC   rC   rD   test_minibatch_reassign  s:    



r   c                   C   s   t ddtdddt d S )Nr   rI   r   T)r.   r   r   r"   r   )r   r    r6   r<   rC   rC   rC   rD   &test_minibatch_with_many_reassignments  s    r   c                  C   sp   t ddddt} | jdks"tt ddddt} | jdksDtt dddtd dt} | jtksltd S )NrI   r\   r)   )r.   r   r/         )r.   r   r/   r   )r   r6   r<   Z
_init_sizer;   r    rg   rC   rC   rD   test_minibatch_kmeans_init_size  s       r   ztol, max_no_improvement)-C6?N)r   rI   c                 C   s   t dddd\}}}td|d|dddd|d	}|| d|j  k rNdk sTn t|  }|d krrd	|jksrt|dkrd
|jkstd S )Nr+   r   T)r!   r"   Zreturn_centersr   rI   r)   )	r.   r0   r   rW   r"   rf   r/   r   max_no_improvementz Converged (small centers change)z*Converged (lack of improvement in inertia))r   r   r6   r:   r;   r   r   )r   rW   r   r<   rd   r!   rg   r   rC   rC   rD   #test_minibatch_declared_convergence  s&    
r   c                  C   s   d} t jd }td| ddt }|jt|j|  | ks@tt	|jt
sPttd| ddd ddt }|jdksxt|jd| |  kstt	|jt
std S )Nr   r   r+   )r.   r   r"   rI   )r.   r   r"   rW   r   rf   )r<   rs   r   r6   r:   r4   ceilZn_steps_r;   r   int)r   r    rg   rC   rC   rD   test_minibatch_iter_steps'  s$    
	r   c                  C   s6   t  } tdtdd}||  t| t| t  d S )NFr   )Zcopy_xr.   r"   )r<   rt   r   r.   r6   r{   r   )Zmy_Xrg   rC   rC   rD   test_kmeans_copyx@  s
    
r   c                 C   s`   t j|dd}| d|dd}|||}| d|dd}|||}||ks\td S )Nr   rI   r)   )r/   r"   rf   )r4   r_   r`   randnr6   Zscorer;   )r   rb   r<   km1s1km2s2rC   rC   rD   test_score_max_iterK  s    r   zEstimator, algorithmrf   r,   c                 C   s   t ddd|d\}}|||d}| ddd||d}|d k	rF|j|d || |j}	||}
t|
|	 ||}
t|
|	 ||j}
t|
t	d d S )Nr   rI   r    rx   r!   r"   r*   r_   )r.   r0   r/   rf   r"   rk   )
r   
set_paramsr6   r7   predictr   fit_predictr9   r4   Zarange)r   r1   r#   rf   Zglobal_dtyperb   r<   rd   rg   rS   predrC   rC   rD   test_kmeans_predictX  s0       





r   c                 C   sl   t j|tf}| t|dd}|jt|d | t|dd}|jt|d t	|j
|j
 t|j|j d S Nr)   r.   r"   r/   r2   )r4   r_   r`   Zrandom_sampler    r.   r6   r<   ru   r   r7   r   r9   )r   rb   r3   Zkm_denseZ	km_sparserC   rC   rD   test_dense_sparse  s"        r   c                 C   s^   t |trdnd}| t||dd}|t t|t|j |t t|t|j d S )NrI   r)   r   r   )	r   r   r.   r6   ru   r   r   r<   r7   )r   r0   r/   rg   rC   rC   rD   test_predict_dense_sparse  s    

r   c           
   	   C   s   t ddgddgddgddgddgddgg}|||d	}|d
krFdnd}|d
kr^|d d n|}| d|||d}| tkr|jdd || |jjt jkstddddddg}	t	t
|j|	d | tkrt||}|jjt jkstd S )Nr   rI      rH   rG   r)   r,   rJ   r*   r   r   )r   r   )r4   r5   r   r   r6   r9   r'   float64r;   r   r   r7   r   r   )
r   r#   r'   r0   rb   X_denser<   r/   rg   r>   rC   rC   rD   test_integer_input  s&    .   
r   c                 C   sb   | t |dt}||j}t|t|j t| t	
t  |t}t|tt|j d S )Nr.   r"   )r.   r6   r<   	transformr9   r   r	   r   Zdiagonalr4   rM   )r   rb   rg   XtrC   rC   rD   test_transform  s    
r   c                 C   s8   | |dd tt}| |ddt}t|| d S )Nr)   )r"   r/   )r6   r<   r   Zfit_transformr   )r   rb   ZX1ZX2rC   rC   rD   test_fit_transform  s    r   c                 C   s:   t j}dD ]*}ttd|| ddt}|j|ks
tq
d S )N)r)   r\   rI   r_   r)   )r.   r0   r/   r"   rf   )r4   infr   r.   r6   r<   r8   r;   )rb   Zprevious_inertiar/   rg   rC   rC   rD   test_n_init  s    r   c                 C   s`   t ttd | d\}}}|jttfks(tt|jd tks@ttt	t
|d |dks\td S )N)r.   r3   r"   r   r   r   )r   r<   r.   rs   rx   r;   r4   ry   r   r   rz   )rb   Zcluster_centersrS   inertiarC   rC   rD   test_k_means_function  s       r   c           
      C   s0  | d|d}i }i }i }i }t jt jfD ]}|j|dd}	||	 |j||< ||	||< |j||< |j||< |jj	|kst
| tkr(||	dd  |jj	|ks(t
q(t|t j |t j dd t|t j |t j |t j  d d	 t|t j |t j |t j  d d	 t|t j |t j  d S )
Nr)   )r/   r"   Frt   r   r+   r   rtol)Zatol)r4   r   float32astyper6   r8   r   r9   r7   r'   r;   r   r   r   maxr   )
r   rO   rb   rg   r   r   r!   rS   r'   r<   rC   rC   rD   test_float_precision  s0    



(  r   c                 C   sJ   t j|dd}tj|dd}| |tdd}|| t|j|rFtd S )NFr   r)   r   )	r<   r   r!   r.   r6   r4   Zmay_share_memoryr9   r;   )r   r'   Z
X_new_typeZcenters_new_typerg   rC   rC   rD   test_centers_not_mutated'  s
    
r   c                 C   s8   t td| }t t|jdd| }t|j|j d S )N)r.   r)   r.   r0   r/   )r   r.   r6   r9   r   )rO   r   r   rC   rC   rD   test_kmeans_init_fitted_centers6  s    r   c              	   C   st   t ddgddgddgddgg}td| d}d}tjt|d* || t|jtt	dksft
W 5 Q R X d S )Nr   r)      r   zmNumber of distinct clusters \(3\) found smaller than n_clusters \(4\). Possibly due to duplicate points in X.rl   r+   )r4   asarrayr   ra   rn   r   r6   setr7   r   r;   )rb   r<   rg   msgrC   rC   rD   1test_kmeans_warns_less_centers_than_unique_points@  s    "
r   c                 C   s   t j| ddS Nr   r   )r4   sortr~   rC   rC   rD   _sort_centersQ  s    r   c                 C   s   t j| jddtd}t jt|dd}ttdt	| d}t
|jt|d}t |j|}t
||}t|j| t|j|j tt|jt|j d S )Nr)   r\   rY   r   r   )r0   r/   r.   r"   r2   )r4   r_   r`   randintr    repeatr<   r   r!   r.   r   r6   r7   r   r   r8   r   r9   )rb   r3   ZX_repeatrg   Zkm_weightedZrepeated_labelsZkm_repeatedrC   rC   rD   test_weighted_vs_repeatedU  s*         r   c                 C   s\   t t}| t|dd}t|j|d d}t|j||d}t|j|j t|j	|j	 d S r   )
r4   rL   r    r.   r   r6   r   r7   r   r9   )r   rO   rb   r3   rg   Zkm_noneZkm_onesrC   rC   rD   test_unit_weights_vs_no_weightso  s    
r   c                 C   sj   t j|jtd}| t|dd}t|j||d}t|j|d| d}t|j	|j	 t
|j|j d S )NrY   r)   r   r2   r(   )r4   r_   r`   uniformr    r.   r   r6   r   r7   r   r9   )r   rO   rb   r3   rg   Zkm_origZ	km_scaledrC   rC   rD   test_scaled_weights~  s    r   c                  C   s$   t dddt} | jdks td S )Nr&   r)   )r1   rf   )r   r6   r<   r:   r;   r   rC   rC   rD    test_kmeans_elkan_iter_attribute  s    r   c                 C   st   | dgdgg}ddg}t dgdgg}td|dd}|j||d tt|jdks\tt|j	dgdgg d S )	NrG   r)   gffffff?g?rI   r,   r   r2   )
r4   r5   r   r6   lenr   r7   r;   r   r9   )r#   r<   r3   r0   rg   rC   rC   rD   #test_kmeans_empty_cluster_relocated  s    r   c              	   C   s~   t j|}|jdd}tddd | t|d|j}W 5 Q R X tddd | t|d|j}W 5 Q R X t|| d S )N)2   rI   rY   r)   Zopenmp)ZlimitsZuser_apir   r,   )	r4   r_   r`   rV   r   r.   r6   r7   r   )r   rb   rc   r<   Zresult_1Zresult_2rC   rC   rD   #test_result_equal_in_diff_n_threads  s    r   c                	   C   s0   t jtdd tdddt W 5 Q R X d S )Nz9algorithm='elkan' doesn't make sense for a single clusterrl   r)   r&   )r.   r1   )ra   rn   r   r   r6   r<   rC   rC   rC   rD   test_warning_elkan_1_cluster  s
    r   c                 C   sz   t j|jdd}|d d }| |}dd }|||\}}tdd||dd|}|j}	|j}
t||	 t	||
 d S )N)r   r\   rY   r\   c                 S   sP   |  }t| |}t|jd D ]}| ||k jdd||< q t| |}||fS r   )rt   r
   r   rs   r   )r<   r0   Znew_centersrS   labelrC   rC   rD   	py_kmeans  s    

z+test_k_means_1_iteration.<locals>.py_kmeansr)   )r.   r/   r0   r1   rf   )
r4   r_   r`   r   r   r6   r7   r9   r   r   )r#   r$   rb   r<   r=   r   Z	py_labelsZ
py_centersZ	cy_kmeansZ	cy_labelsZ
cy_centersrC   rC   rD   test_k_means_1_iteration  s$        
r   squaredTFc                 C   s   t j|}tjdddd|| d}| d}|dj| dd}|d	  }|| d	  }|rh|nt 	|}t
|||}	t|j|j|||}
| t jkrd
nd}t|	|
|d t|	||d t|
||d d S )Nr)   r   r(   csrZdensityformatr"   r'   rG   Fr   r,   r   gHz>r   )r4   r_   r`   sptoarrayrK   r   r   r   sqrtr   r   rO   rP   r   r   )r'   r   rb   rv   Za_sparseZa_densebZb_squared_normexpectedZdistance_dense_denseZdistance_sparse_denser   rC   rC   rD   test_euclidean_distance  s4             r   c                 C   s|  t j|}tjdddd|| d}| }|dj| dd}|ddj| dd}|jddt jd	}|||  d
 j	dd}t 	|| }	t
||||dd}
t||||dd}| t jkrdnd}t|
||d t|
|	|d t||	|d d}||k}|| ||  d
 j	dd}t 	|||  }	t
||||d|d}
t||||d|d}t|
||d t|
|	|d t||	|d d S )Nr   rI   r(   r   r   Fr   r\   )rZ   r'   r,   r)   r   )	n_threadsr   r]   r   )r  Zsingle_label)r4   r_   r`   r   r   r   r   r   rN   r   r   r   r   r   )r'   rb   rv   ZX_sparser   r3   r!   rS   Z	distancesr   Zinertia_denseZinertia_sparser   r   maskrC   rC   rD   test_inertia  sd                       r  zKlass, default_n_initrI   r+   c              	   C   s   | dd}t   t dt |t W 5 Q R X | jdkrBdnd}d| d}|  }tjt|d	 |t W 5 Q R X d S )
Nr)   r/   errorr   rI   r+   z/The default value of `n_init` will change from z to 'auto' in 1.4rl   )	warningscatch_warningssimplefilterro   r6   r<   __name__ra   rn   )Klassdefault_n_initestr   rC   rC   rD   !test_change_n_init_future_warning*  s    


r  c                 C   s\   | ddd}| t |jdks$t| ddd}| t | jdkrP|jdksXndsXtd S )	Nri   r|   )r/   r0   r)   r_   r   rI   r+   )r6   r<   Z_n_initr;   r	  )r
  r  r  rC   rC   rD   test_n_init_auto;  s    

r  c                 C   sV   t dgdgdgg}t dddg}| dddj||d	 t|t dddg d S )
Nr)   r,   r   r(   g?g333333?r   r   r2   )r4   r5   r6   r   )r   r<   r3   rC   rC   rD   test_sample_weight_unchangedF  s    r  zparam, matchr.   r)   z#n_samples.* should be >= n_clusterszIThe shape of the initial centers .* does not match the number of clustersc                 C   s   | d d S )Nr,   rC   ZX_r   r"   rC   rC   rD   r   [  r   rJ   zUThe shape of the initial centers .* does not match the number of features of the datac                 C   s   | d dd df S )NrJ   r,   rC   r  rC   rC   rD   r   e  r   c              	   C   s:   | dd}t jt|d |jf |t W 5 Q R X d S )Nr)   r  rl   )ra   raises
ValueErrorr   r6   r<   )r   paramrm   rg   rC   rC   rD   test_wrong_paramsP  s    
r  x_squared_normszKThe length of x_squared_norms .* should be equal to the length of n_samplesc              	   C   s,   t jt|d tttf|  W 5 Q R X d S )Nrl   )ra   r  r  r   r<   r.   )r  rm   rC   rC   rD   !test_kmeans_plusplus_wrong_paramst  s    r  c                 C   s   |  |} t| t|d\}}|jd tks.t|dk s>t|| jd k sTt|jd tksft|jdd| jddk st|jdd| jddk sttt	|  || d S )Nr[   r   r   )
r   r   r.   rs   r;   allr   minr   r<   )rO   r'   rb   r!   rP   rC   rC   rD   test_kmeans_plusplus_output  s    
  
  r  )r   c                 C   s$   t tt| d\}}tt| | d S )N)r  )r   r<   r.   r   )r  r!   rP   rC   rC   rD   test_kmeans_plusplus_norms  s    r  c                 C   s<   t tt| d\}}tt}t |t| d\}}t|| d S )Nr[   )r   r<   r.   r4   r   r   )rb   Z	centers_crd   r   r   rC   rC   rD   test_kmeans_plusplus_dataorder  s    
  
r  c               	   C   s   t jddddddddgt jd} t| | ds0tt jddddddddgt jd}t| |ds`tt jddddddddgt jd}t| |drtd S )Nr)   r   r,   r*   r+   )r4   r5   rN   r   r;   )Zlabels1Zlabels2Zlabels3rC   rC   rD   test_is_same_clustering  s       r  kwargs)r0   r/   c                 C   sH   t jddgddgddgddggt jd}tf ddi| }|| dS )zZCheck that init works with numpy scalar strings.

    Non-regression test for #21964.
    r   r(   r)   r*   r.   r,   N)r4   r   r   r   r6   )r  r<   Z
clusteringrC   rC   rD   -test_kmeans_with_array_like_or_np_scalar_init  s    (r  zKlass, methodr6   r   c                    sR   | j   |  }t||t |jjd }| }t fddt|D | dS )z=Check `feature_names_out` for `KMeans` and `MiniBatchKMeans`.r   c                    s   g | ]}  | qS rC   rC   ).0r   
class_namerC   rD   
<listcomp>  s     z*test_feature_names_out.<locals>.<listcomp>N)	r	  lowergetattrr<   r9   rs   Zget_feature_names_outr   r   )r
  methodrB   r.   Z	names_outrC   r   rD   test_feature_names_out  s    
r&  	is_sparsec                 C   sb   t ddddd\}}| r"t|}t }||}t|j|_t|j|_||}t	|| dS )z_Check that predict does not change cluster centers.

    Non-regression test for gh-24253.
    r   rI   r   r   N)
r   r   
csr_matrixr   r   r   r9   r7   r   r   )r'  r<   rd   rB   Zy_pred1Zy_pred2rC   rC   rD   ,test_predict_does_not_change_cluster_centers  s    


r)  )}__doc__r   r   r  numpyr4   Zscipyr   r   ra   Zsklearn.utils._testingr   r   Zsklearn.utils.fixesr   Zsklearn.baser   Zsklearn.exceptionsr   Zsklearn.utils.extmathr   Zsklearn.metricsr	   r
   Zsklearn.metrics.clusterr   Zsklearn.clusterr   r   r   r   Zsklearn.cluster._kmeansr   r   Zsklearn.cluster._k_means_commonr   r   r   r   r   r   r   r   Zsklearn.datasetsr   ior   r   markfilterwarningsZ
pytestmarkr5   r!   r    rs   r.   rx   r<   rz   r(  ru   Zparametrizer   r   rE   rF   rT   re   rh   rq   rw   r{   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rN   Zint64r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r  r  r  r  r  Zstr_r  r&  r)  rC   rC   rC   rD   <module>   s  
   

 
  
  
 
) 
 

?
	





.

 
 #
   
 

$
	 
 

	 
 
+



	

	


 "

