U
    md0                     @   s  d dl Z d dlZd dlZd dlmZ d dlmZ d dlZd dl	m
Z
 d dl	mZ d dl	mZ d dl	mZ d dl	mZ d d	lmZ d d
lmZ d dlmZ d dlmZmZmZ d dlmZ d dlmZmZmZ d dlmZ d dlm Z  d dl!m"Z" d dl#m$Z$ d dl%m&Z& d dl'm(Z( d dl)m*Z* dd Z+dd Z,dd Z-e j./ddddd gd!d" Z0e j./ddddgd#d$ Z1e j./ddddgd%d& Z2e j./ddddd gd'd( Z3d)d* Z4d+d, Z5d-d. Z6d/d0 Z7e j./dddge j./d1de8e9gd2d3 Z:e j./dddge j./d4d5d6gd7d8 Z;e j./dd dge j./d1e9e<d9e<d:gd;d< Z=d=d> Z>e j./d?dej?d@dAd gdBdC Z@e j./d1e8dDgdEdF ZAe j./dGdHdIej?fgdJdK ZBdLdM ZCe j./dNejDejEgdOdP ZFe j./d?dej?d@dAd gdQdR ZGe j./d1e8dDgdSdT ZHe j./dUdVgdWggdVgej?gggdXdY ZIdZd[ ZJd\d] ZKd^d_ ZLd`da ZMdbdc ZNe j./dddedfdgdhdigdjdk ZOe j./dlde e e e gdmdn ZPdodp ZQdqdr ZRdsdt ZSe j./ddddgdudv ZTdwdx ZUdydz ZVd{d| ZWe j./d}d~dgdd ZXdd ZYdd ZZdd Z[e j.j/dd de\d gd~ dgd~ gfdde\ej] gd~ ej]gd~ gfej] ej]e\ej] gd~ ej]gd~ gfdddgdddge\dddgdddggfdej] dgddej]ge\dej] dgddej]ggfgdddddgddd Z^e j./ddej]ej] dfddgddd gdfgdd Z_e j.j/dddgej] ej]gfddgdgd dgd gfgddgddd Z`e j./dddgdd Zae j./dddVejbjcdVdge j./dddVejbjcdVdgdd Zde j./de\ddVgdVdWgge\ddVgdVdggddddfej\ddgddgge9dej\ddgddgge9di dfgdd Zee j./dej?ejfej\fd ejgej\fdejgej\fej?ejfejhfdejgejhfej?ejfejDfdejgejDfej?ejfejifdejgejifej?ejfejjfdejgejjfej?ejfejkfdejgejkfge j./ddd~e\d dVdWgfdd~e\d dVdWgfgdd Zle j./dejhejDejiejjejkgdd Zme j./ddddge j./dej?ej\fd ej\fej?ejhfej?ejDfej?ejifej?ejjfgddń ZnddǄ Zoe j./dej\ddgddgge8ddej\ddddgddddgge8dfe\ej?dIgdIej?ggej?e\dIdIddgdIdIddggfej\ej?dgdej?gge8dej?ej\ddddgddddgge8dfej\ddgddgge8ddej\ddddgddddgge8dfgddʄ Zpe j./deege j./ddej?dfdgddф Zqddӄ ZrddՄ Zse j./deegddׄ Zte j./dejhejDejiejjejkgddل Zue j./dddgddބ Zve j./ddgd~ddWd dVgfdhdVd dWdd~gfgdd Zwe j./ddej?gdd Zxe j./ddej?gdd Zye j./dddddge8ddWfddddge8ddVfdddge8ddWfddddge8ddWfddVdWd~gezddWfdVdVdVdWgezddVfddddVgezddWfdVdVdVdgezddWfgdd Z{e j./ddddd gdd Z|e j./dddgdd Z}dd Z~dd Zdd Ze j./dejejfgdd Ze j./d ddge j./dddgdd Ze j./d ddge j./ddddge j./dddgdd ZdS (      Nsparse)kstest)_convert_container)assert_allclose)assert_allclose_dense_sparse)assert_array_equal)assert_array_almost_equal)enable_iterative_imputer)load_diabetes)MissingIndicator)SimpleImputerIterativeImputer
KNNImputer)DummyRegressor)BayesianRidgeARDRegressionRidgeCV)Pipeline)
make_union)GridSearchCV)tree)_sparse_random_matrix)ConvergenceWarning)_most_frequentc                 C   s   t | | | j|jkstd S N)r   dtypeAssertionErrorxy r!   Y/home/sam/Atlas/atlas_env/lib/python3.8/site-packages/sklearn/impute/tests/test_impute.py"_assert_array_equal_and_same_dtype!   s    
r#   c                 C   s   t | | | j|jkstd S r   )r   r   r   r   r!   r!   r"   _assert_allclose_and_same_dtype&   s    
r$   c           	      C   s   d||f }t }| jjdks(|jjdkr,t}t||d}|| |  }||j||	dd ||||	dd t||d}|t
|  |t
|  }t
|r| }||j||	dd ||||	dd dS )zUtility function for testing imputation for a given strategy.

    Test with dense and sparse arrays

    Check that:
        - the statistics (mean, median, mode) are correct
        - the missing values are imputed correctlyz<Parameters: strategy = %s, missing_values = %s, sparse = {0}fmissing_valuesstrategyF)err_msgTN)r   r   kindr	   r   fit	transformcopyZstatistics_formatr   
csc_matrixissparsetoarray)	XX_truer(   
statisticsr'   r)   Z	assert_aeimputerX_transr!   r!   r"   _check_statistics+   s$    	
r7   r(   meanmedianmost_frequentconstantc                 C   s   t jdd}t j|d d d< t| d}|t|}|jdksFt	||}|jdks^t	t
| d}||}|jdkst	d S )N
      r(   )r<   r=   )initial_strategy)nprandomrandnnanr   fit_transformr   
csr_matrixshaper   r   )r(   r2   r5   	X_imputedZiterative_imputerr!   r!   r"   test_imputation_shapeP   s    



rH   c              	   C   st   t d}t j|d d df< t| dd}tjtdd || W 5 Q R X tjtdd |	| W 5 Q R X d S )N      r      r(   verboseThe 'verbose' parametermatchZSkipping)
r@   onesrC   r   pytestwarnsFutureWarningr+   UserWarningr,   r(   r2   r5   r!   r!   r"    test_imputation_deletion_warninga   s    
rX   c              	   C   s   t d}tj}tjddddgtd}|j||d|gd|d	d
gg|d}t| dd}t jt	dd |
| W 5 Q R X t|j| t jtdd || W 5 Q R X d S )Npandasabcdr   rL      r=   r<   columnsrM   rO   rP   z6Skipping features without any observed values: \['b'\])rS   importorskipr@   rC   arrayobject	DataFramer   rT   rU   r+   r   Zfeature_names_in_rV   r,   )r(   pdr'   feature_namesr2   r5   r!   r!   r"   .test_imputation_deletion_warning_feature_nameso   s$    


 rh   c              	   C   s   t d}d|d< t|}t| dd}tjtdd || W 5 Q R X ||	  tjtdd |
| W 5 Q R X d S )NrI   r   )r(   r'   zProvide a dense arrayrP   )r@   rR   r   r/   r   rS   raises
ValueErrorr+   r1   r,   rW   r!   r!   r"   test_imputation_error_sparse_0   s    

rk   c                 O   s8   t | dr| jnt| }|dkr&tjS tj| f||S Nsizer   )hasattrrm   lenr@   rC   r9   Zarrargskwargslengthr!   r!   r"   safe_median   s    rt   c                 O   s8   t | dr| jnt| }|dkr&tjS tj| f||S rl   )rn   rm   ro   r@   rC   r8   rp   r!   r!   r"   	safe_mean   s    ru   c               
   C   sv  t jd} d}d}|| || f}t |d }t d|d d }|dd d  |dd d< dt jdd fd	t jd
d fg}|D ]\}}}	t |}
t |}t |d }t|d D ]Z}|| d dk|| d  || d  }t|d ||  || ||   d}|d | | }|d | }t 	||}|| 
t|d |  }|	|||||< t |||f|
d d |f< d|krt |t 	|| || f|d d |f< n(t ||t 	|| |f|d d |f< t j||
d d |f  t j||d d |f  q|d	kr<t |jdd }nt |jdd }|d d |f }t|
|||| qd S )Nr   r<   rL   r_   r=   r8   c                 S   s   t t| |fS r   )ru   r@   hstackzvpr!   r!   r"   <lambda>       z-test_imputation_mean_median.<locals>.<lambda>r9   c                 S   s   t t| |fS r   )rt   r@   rv   rw   r!   r!   r"   r{      r|   )Zaxis)r@   rA   RandomStatezerosarangerC   emptyrangemaxrepeatZpermutationro   rv   shuffleisnananyallr7   )rngdimdecrF   r~   valuestestsr(   Ztest_missing_valuesZtrue_value_funr2   r3   Ztrue_statisticsjZnb_zerosZnb_missing_valuesZ	nb_valuesrx   rz   ry   Zcols_to_keepr!   r!   r"   test_imputation_mean_median   sJ    

(&
 
r   c                  C   s   t dt jt jgdt jt jgddt jgddt jgddt jgddt jgddt jgddt jgg } t dddgdddgdddgdddgddd	gddd
gdddgdddgg }ddddd	d
ddg}t| |d|t j d S )Nr   rK   r_   r=   g      g      @g      @g            ?r9   )r@   rc   rC   Z	transposer7   )r2   ZX_imputed_medianZstatistics_medianr!   r!   r"   $test_imputation_median_special_cases   s0    





r   r   c              	   C   s\   t jdddgdddgddd	gg|d
}d}tjt|d t| d}|| W 5 Q R X d S )NrZ   r[   rJ   r_   e   gh	   r^   6non-numeric data:
could not convert string to float: 'rP   r>   )r@   rc   rS   ri   rj   r   rD   )r(   r   r2   msgr5   r!   r!   r"   .test_imputation_mean_median_error_invalid_type  s
    &
r   typelist	dataframec              	   C   sn   dddgdddgddd	gg}|d
kr8t d}||}d}t jt|d t| d}|| W 5 Q R X d S )NrZ   r[   rJ   r_   r   r   r   r   r   r   rY   r   rP   r>   )rS   rb   re   ri   rj   r   rD   )r(   r   r2   rf   r   r5   r!   r!   r"   :test_imputation_mean_median_error_invalid_type_list_pandas  s    


r   USc              	   C   s   t jt jt jddgt jdt jdgt jddt jgt jdddgg|d}d}tjt|d	  t| d
}||| W 5 Q R X d S )NrZ   r%   r\   r]   r[   r   r^   z#SimpleImputer does not support datarP   r>   )	r@   rc   rC   rS   ri   rj   r   r+   r,   )r(   r   r2   r)   r5   r!   r!   r"   /test_imputation_const_mostf_error_invalid_types$  s    

r   c               	   C   sz   t ddddgddddgddddgddddgg} t dddgdddgdddgdddgg}t| |dt jdddgd d S )	Nr   r   rK   r=   rJ   rL      r:   )r@   rc   r7   rC   )r2   r3   r!   r!   r"   test_imputation_most_frequent9  s    



	r   markerZNAN c                 C   s   t j| | ddg| d| dg| dd| g| dddggtd}t jdddgdddgdddgdddggtd}t| dd	}|||}t|| d S )
NrZ   r%   r\   r]   r[   r   r^   r:   r&   )r@   rc   rd   r   r+   r,   r   r   r2   r3   r5   r6   r!   r!   r"   %test_imputation_most_frequent_objectsT  s&    





r   categoryc                 C   sr   t d}td}|j|| d}tjdddgdddgdddgd	ddggtd}td
d}|	|}t
|| d S )NrY   ,Cat1,Cat2,Cat3,Cat4
,i,x,
a,,y,
a,j,,
b,j,x,r^   rZ   ir   r   r    r[   r:   r>   rS   rb   ioStringIOZread_csvr@   rc   rd   r   rD   r   r   rf   r%   dfr3   r5   r6   r!   r!   r"   $test_imputation_most_frequent_pandasq  s    

"

r   zX_data, missing_value)rL   r         ?c              	   C   sN   t jd| td}||d< tjtdd t|ddd}|| W 5 Q R X d S )	NrI   r^   r   r   zimputing numericalrP   r;   r   r'   r(   
fill_value)r@   fullfloatrS   ri   rj   r   rD   )ZX_datamissing_valuer2   r5   r!   r!   r"   +test_imputation_constant_error_invalid_type  s      r   c               	   C   s   t ddddgddddgddddgdd	d
dgg} t d
ddd
gdd
dd
gddd
d
gdd	d
d
gg}tddd
d}|| }t|| d S )Nr   r=   rJ   r_   rK   r   r      r   r   r;   r   )r@   rc   r   rD   r   )r2   r3   r5   r6   r!   r!   r"    test_imputation_constant_integer  s
    22
r   array_constructorc              	   C   s   t t jddt jgdt jdt jgddt jt jgdddt jgg}t ddddgddddgddddgddddgg}| |}| |}tddd	}||}t|| d S )
Ng?r   333333??gffffff?      ?r   r;   )r(   r   )r@   rc   rC   r   rD   r   )r   r2   r3   r5   r6   r!   r!   r"   test_imputation_constant_float  s    	*
r   c                 C   s   t j| dd| gd| d| gdd| | gddd	| ggtd
}t jddddgddddgddddgddd	dggtd
}t| ddd}||}t|| d S )NrZ   r[   r\   r]   r   r%   r   r   r   r^   missingr;   r   )r@   rc   rd   r   rD   r   r   r!   r!   r"   test_imputation_constant_object  s.    









  
r   c                 C   sz   t d}td}|j|| d}tjddddgddddgdd	ddgd
d	ddggtd}tdd}|	|}t
|| d S )NrY   r   r^   r   r   r   rZ   r    r   r[   r;   r>   r   r   r!   r!   r"   test_imputation_constant_pandas  s    








r   r2   rL   r=   c                 C   sf   t  | }|jdkstt  }|dgdgg |jdks@t|dgtjgg |jdksbtd S )Nr   rL   r=   )r   r+   n_iter_r   r@   rC   r2   r5   r!   r!   r"   "test_iterative_imputer_one_feature  s    r   c                  C   st   t dddd} | jd }tdt|dfdtjddfg}d	d
ddgi}t dddd }t||}|| | d S )Nd   皙?)densityr   r5   r'   r   random_stateZimputer__strategyr8   r9   r:   rL   )	r   datar   r   r   ZDecisionTreeRegressorr1   r   r+   )r2   r'   Zpipeline
parametersYgsr!   r!   r"   $test_imputation_pipeline_grid_search  s    

r   c                  C   sv  t ddddd} |   }tdddd}|||}d|d	< t||krTt|  }t|j	d ddd}|||}d|j	d< t|j	|j	krt|   }tddd
d}|||}d|d	< t
|| |   }t|j	d dd
d}|||}d|j	d< t
|j	|j	 |  }t|j	d dd
d}|||}d|j	d< t|j	|j	krrtd S )NrK   g      ?r   r   r   r8   T)r'   r(   r-   r   r   F)r   r-   r1   r   r+   r,   r@   r   r   r   r	   Ztocsc)ZX_origr2   r5   Xtr!   r!   r"   test_imputation_copy  s4    



r   c                  C   s   t jd} d}d}t||d| d }|dk}t j||< tdd}||}t||j	
| tdd|}t |
||j	
|krtd|_t|
||j	
| d S )Nr   r   r<   r   r   )max_iterrK   )r@   rA   r}   r   r1   rC   r   rD   r   initial_imputer_r,   r+   r   r   r   )r   nr]   r2   Zmissing_flagr5   rG   r!   r!   r"   !test_iterative_imputer_zero_iters:  s    


 r   c                  C   sp   t jd} d}d}t||d| d }tdddd}|| || tdddd}|| || d S )	Nr   r   rJ   r   r   rL   )r'   r   rN   r=   )r@   rA   r}   r   r1   r   r+   r,   )r   r   r]   r2   r5   r!   r!   r"   test_iterative_imputer_verboseR  s    


r   c                  C   sB   d} d}t | |f}tddd}||}t||j| d S )Nr   rJ   r   rL   )r'   r   )r@   r~   r   rD   r   r   r,   )r   r]   r2   r5   rG   r!   r!   r"   "test_iterative_imputer_all_missing`  s    
r   imputation_orderrA   roman	ascending
descendingarabicc           
      C   sR  t jd}d}d}d}t||d|d }d|d d df< td|dd	d
ddd| |d
}|| dd |jD }t||j	 |j
kst| dkrt |d |d  t d|kstn| dkrt |d |d  t |d ddkstn^| dkr*|d |d  }||d d  }	||	ksNtn$d| krNt|||d  ksNtd S )Nr   r   r<   r=   r   r   rL   rK   FT)
r'   r   n_nearest_featuressample_posteriorskip_complete	min_value	max_valuerN   r   r   c                 S   s   g | ]
}|j qS r!   Zfeat_idx).0r   r!   r!   r"   
<listcomp>  s     z;test_iterative_imputer_imputation_order.<locals>.<listcomp>r   r   r   rA   ending)r@   rA   r}   r   r1   r   rD   imputation_sequence_ro   r   Zn_features_with_missing_r   r   r   )
r   r   r   r]   r   r2   r5   Zordered_idxZordered_idx_round_1Zordered_idx_round_2r!   r!   r"   'test_iterative_imputer_imputation_orderi  s>    
(.

r   	estimatorc           	      C   s   t jd}d}d}t||d|d }tdd| |d}|| g }|jD ]>}| d k	r`t| ntt	 }t
|j|szt|t|j qLtt|t|kstd S )Nr   r   r<   r   r   rL   )r'   r   r   r   )r@   rA   r}   r   r1   r   rD   r   r   r   
isinstancer   r   appendidro   set)	r   r   r   r]   r2   r5   hashestripletexpected_typer!   r!   r"   !test_iterative_imputer_estimators  s$       

r   c                  C   s   t jd} d}d}t||d| d }tdddd| d}||}tt ||dk d tt 	||dk d t||dk ||dk  d S )	Nr   r   r<   r   r   rL   皙?)r'   r   r   r   r   
r@   rA   r}   r   r1   r   rD   r   minr   r   r   r]   r2   r5   r   r!   r!   r"   test_iterative_imputer_clip  s        
r   c                  C   s   t jd} d}d}t||d| d }d|d d df< tdddd	dd
dd| d	}||}tt ||dk d tt 	||dk d
 t||dk ||dk  d S )Nr   r   r<   r   r   rL   r=   rK   Tr   rA   )	r'   r   r   r   r   r   rN   r   r   r   r   r!   r!   r"   %test_iterative_imputer_clip_truncnorm  s(    
r   c                     s   t jd} | jdd t j d d< tddd| d  t  fdd	td
D }t	|dksnt
t	|dks~t
| |  }}t|| | d\}}|dkr|d7 }t|| | d\}}|dk s|dkst
dd S )N*   )rK   rK   )rm   r   r   T)r   r   r   r   c                    s   g | ]}  d  d  qS )r   )r,   )r   _r   r!   r"   r     s     zEtest_iterative_imputer_truncated_normal_posterior.<locals>.<listcomp>r   Znormg-q=r   r   z&The posterior does appear to be normal)r@   rA   r}   normalrC   r   rD   rc   r   r   r   r8   Zstdr   )r   ZimputationsmusigmaZks_statisticZp_valuer!   r   r"   1test_iterative_imputer_truncated_normal_posterior  s&       
r   c                 C   s   t jd}d}d}|jdd||fd}|jdd||fd}d|d d df< d|d< tdd| |d|}td| d	|}t||d d df ||d d df  d S )
Nr   r   r<   rJ   )lowhighrm   rL   r   )r'   r   r?   r   r&   )	r@   rA   r}   randintr   r+   r   r   r,   )r(   r   r   r]   X_trainX_testr5   Zinitial_imputerr!   r!   r"   +test_iterative_imputer_missing_at_transform  s(        r  c                  C   s   t jd} t jd}d}d}t||d| d }tddd| d}|| ||}||}t |t	
t |ksttddd	d d
| d}tddd	d d
|d}	|| |	| ||}
||}|	|}t|
| t|
| d S )Nr   rL   r   r<   r   r   T)r'   r   r   r   Fr   )r'   r   r   r   r   r   )r@   rA   r}   r   r1   r   r+   r,   r8   rS   Zapproxr   r   )Zrng1Zrng2r   r]   r2   r5   Z
X_fitted_1Z
X_fitted_2imputer1imputer2ZX_fitted_1aZX_fitted_1br!   r!   r"   .test_iterative_imputer_transform_stochasticity  sL       


	





r  c                  C   s   t jd} | dd}t j|d d df< td| d}td| d}|||}||}t	|d d dd f | t	|| d S )Nr   r   r<   )r   r   rL   )
r@   rA   r}   randrC   r   r+   r,   rD   r   )r   r2   m1m2Zpred1Zpred2r!   r!   r"   !test_iterative_imputer_no_missing?  s    
r  c            	      C   s   t jd} d}| |d}| d|}t ||}| ||dk }| }t j||< tdd| d}||}t	||dd d S )	Nr   2   rL   r   rK   r   rN   r   g{Gz?atol)
r@   rA   r}   r	  dotr-   rC   r   rD   r   )	r   r]   ABr2   nan_mask	X_missingr5   X_filledr!   r!   r"   test_iterative_imputer_rank_oneM  s    

r  rankrJ   rK   c                 C   s   t jd}d}d}||| }|| |}t ||}|||dk }| }t j||< |d }|d | }	||d  }
||d  }tddd|d|	}|	|}t
|
|d	d
 d S )Nr   F   r   r=   rK   r   rL   )r   r   rN   r   r   r  )r@   rA   r}   r	  r  r-   rC   r   r+   r,   r   )r  r   r   r]   r  r  r  r  r  r  X_test_filledr  r5   
X_test_estr!   r!   r"   )test_iterative_imputer_transform_recovery\  s.    
   
r  c               	   C   s  t jd} d}d}| ||}| ||}t |j}t|D ]R}t|D ]D}|d d || | f  |d d |f |d d |f  d 7  < qLq@| ||dk }| }	t j	|	|< |d }|	d | }
||d  }|	|d  }t
dd| d|
}||}t||dd	d
 d S )Nr   r   r<   r=   g      ?rL   r  gMbP?{Gz?)rtolr  )r@   rA   r}   rB   r~   rF   r   r	  r-   rC   r   r+   r,   r   )r   r   r]   r  r  r  r   r   r  r  r  r  r  r5   r  r!   r!   r"   &test_iterative_imputer_additive_matrixu  s&    D

r  c                  C   s   t jd} d}d}| |d}| d|}t ||}| ||dk }| }t j||< tdddd| d	}||}	t	|j
||j kstt|jdd| d
}||}
t|	|
dd tdddd| d	}|| |j|jkstd S )Nr   r  rK   rL   r   r   r  F)r   Ztolr   rN   r   )r   r   rN   r   gHz>r  )r@   rA   r}   r	  r  r-   rC   r   rD   ro   r   r   r   r   r+   r   )r   r   r]   r  r  r2   r  r  r5   ZX_filled_100ZX_filled_earlyr!   r!   r"   %test_iterative_imputer_early_stopping  sF    
    
   
    
r   c            
   	   C   s   t dd\} }| j\}}d| d d df< tjd}d}t|D ]0}|jt|t|| dd}tj	| ||f< q@t
d	dd
}t  tdt || |}	W 5 Q R X tt|	rtd S )NT)Z
return_X_yrL   rJ   r   g333333?F)rm   replacerK   )r   r   error)r   rF   r@   rA   r}   r   choicer   intrC   r   warningscatch_warningssimplefilterRuntimeWarningrD   r   r   r   )
r2   r    Z	n_samples
n_featuresr   Zmissing_rateZfeatZ
sample_idxr5   ZX_fillr!   r!   r"   $test_iterative_imputer_catch_warning  s"    
 
 
r*  z$min_value, max_value, correct_outputr   r   r<      i,  ZscalarszNone-defaultinflistszlists-with-inf)Zidsc                 C   s   t jddd}t| |d}|| t|jt jrFt|j	t jsJt
|jjd |jd krv|j	jd |jd kszt
t|dd d f |j t|dd d f |j	 d S )Nr   r<   rJ   r   r   rL   )r@   rA   r}   rB   r   r+   r   Z
_min_valuendarrayZ
_max_valuer   rF   r   )r   r   Zcorrect_outputr2   r5   r!   r!   r"   )test_iterative_imputer_min_max_array_like  s    
 r0  zmin_value, max_value, err_msg)r   r   min_value >= max_value.r1  z_value' should be of shapec              	   C   s@   t jd}t| |d}tjt|d || W 5 Q R X d S )Nr<   rJ   r.  rP   )r@   rA   r   rS   ri   rj   r+   )r   r   r)   r2   r5   r!   r!   r"   *test_iterative_imputer_catch_min_max_error  s    r3  zmin_max_1, min_max_2ir_   zNone-vs-infzScalar-vs-vectorc              	   C   s   t t jdddgdt jt jdgddt jdgt jddt jgg}t t jdt jdgddt jt jgt jdddgg}t| d | d dd	}t|d |d dd	}|||}|||}t|d d df |d d df  d S )
Nr=   rL   r<   r   rJ   r_   rK   r   )r   r   r   )r@   rc   rC   r   r+   r,   r   )Z	min_max_1Z	min_max_2r  r  r  r  ZX_test_imputed1ZX_test_imputed2r!   r!   r"   4test_iterative_imputer_min_max_array_like_imputation  s.    *    r4  r   TFc              	   C   s   t jd}t ddddgddddgddddgdd	ddgg}t t jdd	dgt jd	ddgt jdddgg}td
| |d}|||}| rt|d d df t 	|d d df  n t|d d df dddgdd d S )Nr   rK   r=   rL   r<   r   rJ   r   r_   r8   )r?   r   r         g-C6?)r  )
r@   rA   r}   rc   rC   r   r+   r,   r   r8   )r   r   r  r  r5   r  r!   r!   r"   'test_iterative_imputer_skip_non_missing  s    2.  *r7  
rs_imputer)seedrs_estimatorc                 C   sH   G dd d}||d}t | d}td}|| |j|ksDtd S )Nc                   @   s$   e Zd Zdd Zdd Zdd ZdS )zCtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimatorc                 S   s
   || _ d S r   r   )selfr   r!   r!   r"   __init__,  s    zLtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.__init__c                 _   s   | S r   r!   )r;  rq   Zkgardsr!   r!   r"   r+   /  s    zGtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.fitc                 S   s   t |jd S )Nr   )r@   r~   rF   )r;  r2   r!   r!   r"   predict2  s    zKtest_iterative_imputer_dont_set_random_state.<locals>.ZeroEstimator.predictN)__name__
__module____qualname__r<  r+   r=  r!   r!   r!   r"   ZeroEstimator+  s   rA  r   r2  )r   r@   r~   r+   r   r   )r8  r:  rA  r   r5   r  r!   r!   r"   ,test_iterative_imputer_dont_set_random_state(  s    




rB  zX_fit, X_trans, params, msg_errr   missing-onlyauto)featuresr   zBhave missing values in transform but have no missing values in fitrZ   r[   r\   r^   z1MissingIndicator does not support data with dtypec              	   C   sD   t dd}|jf | tjt|d || | W 5 Q R X d S )Nr   r   rP   )r   
set_paramsrS   ri   rj   r+   r,   )X_fitr6   paramsZmsg_err	indicatorr!   r!   r"   test_missing_indicator_error<  s    
rJ  zmissing_values, dtype, arr_typez,param_features, n_features, features_indicesr   c                 C   s  t | | dgdd| gg}t | | dgdddgg}t dddgdddgg}t dddgdddgg}	|||}|||}||}|	|}	t| |dd}
|
|}|
|}|jd |kst|jd |kstt|
j	| t
||d d |f  t
||	d d |f  |jtks&t|jtks6tt|t jsHtt|t jsZt|
jd	d
 |
|}|
|}|jtkst|jtkst|jdkst|jdkstt
| | t
| | d S )NrL   r_   r=   r6  r<   r   F)r'   rE  r   Tr   csc)r@   rc   astyper   rD   r,   rF   r   r   Z	features_r   r   boolr   r/  rF  r.   r1   )r'   arr_typer   Zparam_featuresr)  Zfeatures_indicesrG  r6   ZX_fit_expectedZX_trans_expectedrI  
X_fit_maskX_trans_maskZX_fit_mask_sparseZX_trans_mask_sparser!   r!   r"   test_missing_indicator_newT  sB    

  



rQ  rN  c              	   C   s   d}t ||dgd|dgg}t ||dgdddgg}| |}| |}t|d}tjtdd	 || W 5 Q R X || tjtdd	 || W 5 Q R X d S )
Nr   rL   r_   r=   r6  r<   r   z"Sparse input with missing_values=0rP   )r@   rc   r   rS   ri   rj   rD   r,   )rN  r'   rG  r6   ZX_fit_sparseZX_trans_sparserI  r!   r!   r"   5test_missing_indicator_raise_on_sparse_with_missing_0  s    

rR  param_sparsezmissing_values, arr_typec                 C   sL  t ||dgd|dgg}t ||dgdddgg}| |t j}| |t j}t||d}||}||}|dkr|jdkst|jdkstn|d	kr|d
krt	|t j
stt	|t j
stn||dkrt	|t j
stt	|t j
stnRt|r$|jdkst|jdksHtn$t	|t j
s6tt	|t j
sHtd S )NrL   r_   r=   r6  r<   )r'   r   TrK  rD  r   F)r@   rc   rL  float64r   rD   r,   r.   r   r   r/  r   r0   )rN  r'   rS  rG  r6   rI  rO  rP  r!   r!   r"   #test_missing_indicator_sparse_param  s*    

rU  c                  C   sX   t jdddgdddggtd} tddd}|| }t|t dddgdddgg d S )	NrZ   r[   r\   r^   r   )r'   rE  TF)r@   rc   rd   r   rD   r   )r2   rI  r6   r!   r!   r"   test_missing_indicator_string  s    
rV  zX, missing_values, X_trans_expc                 C   s0   t t|ddt|d}|| }t|| d S )Nr:   r&   r   )r   r   r   rD   r   )r2   r'   ZX_trans_expZtransr6   r!   r!   r"   #test_missing_indicator_with_imputer  s    

rW  imputer_constructorz.imputer_missing_values, missing_value, err_msgNaNzInput X contains NaN)z-1r   z(types are expected to be both numerical.c              	   C   sR   t jd}|dd}||d< | |d}tjt|d || W 5 Q R X d S )Nr   r<   r   r   rP   )r@   rA   r}   rB   rS   ri   rj   rD   )rX  Zimputer_missing_valuesr   r)   r   r2   r5   r!   r!   r"   (test_inconsistent_dtype_X_missing_values  s    
rZ  c                  C   sB   t ddgddgg} tddd}|| }|jd dks>td S )NrL   rC  r   rE  r'   r   )r@   rc   r   rD   rF   r   r2   mir   r!   r!   r"   !test_missing_indicator_no_missing  s    
r^  c                  C   sP   t dddgdddgdddgg} tddd}|| }| | ksLtd S )Nr   rL   r=   r   r[  )r   rE   r   rD   Zgetnnzsumr   r\  r!   r!   r"   /test_missing_indicator_sparse_no_explicit_zeros)  s    "
r`  c                 C   s8   t ddgddgg}|  }|| |jd ks4td S )NrL   )r@   rc   r+   Z
indicator_r   )rX  r2   r5   r!   r!   r"   test_imputer_without_indicator4  s    
ra  c                 C   s   | t jddgdt jdgddt jgdddgg}t ddd	dd
d
gdddd
dd
gddd	d
d
dgdddd
d
d
gg}tt jdd}||}t|st|j|jkstt	|
 | d S )NrL   rK   r=   r   rJ   r         @r   g      @g               @g      @g      "@T)r'   add_indicator)r@   rC   rc   r   rD   r   r0   r   rF   r   r1   )rN  ZX_sparser3   r5   r6   r!   r!   r"   2test_simple_imputation_add_indicator_sparse_matrix=  s    .	
re  zstrategy, expected)r:   r[   )r;   r   c                 C   sN   ddgdt jgg}t jddgd|ggtd}t| d}||}t|| d S )NrZ   r[   r\   r^   r>   )r@   rC   rc   rd   r   rD   r   )r(   expectedr2   r3   r5   r6   r!   r!   r"   "test_simple_imputation_string_listZ  s
    

rg  zorder, idx_orderc              	   C   s   t jd}|dd}t j|d ddf< t j|d ddf< t j|d dd	f< t j|d d
df< tt6 td| dd	|}dd |j
D }||kstW 5 Q R X d S )Nr   r   rK   r  rL      r      r=   r<   r_   )r   r   r   c                 S   s   g | ]
}|j qS r!   r   )r   r   r!   r!   r"   r   y  s     z)test_imputation_order.<locals>.<listcomp>)r@   rA   r}   r	  rC   rS   rT   r   r   r+   r   r   )orderZ	idx_orderr   r2   Ztrsidxr!   r!   r"   test_imputation_orderh  s    rl  r   c              	   C   sD  t d| ddgddddgdd| dgddd	| gg}t ddd
dgd
d| dgd| ddgddd
| gg}t d| ddg| d| | gd
| d| g| d| dgg}t ddddg| d
| dgd
dddg| d| d
gg}t| ddd}||}||}||}||}	t|| t|	| ||fD ]$}
||
}||}t||
 qd S )Nr   rJ   r   r_   rK   r   r   r   r   r=   rL   r8   T)r'   r(   rd  )r@   rc   r   rD   inverse_transformr,   r   )r   X_1ZX_2ZX_3ZX_4r5   	X_1_transZX_1_inv_transZ	X_2_transZX_2_inv_transr2   r6   ZX_inv_transr!   r!   r"   (test_simple_imputation_inverse_transform}  sV    



	



	



	



	  







rp  c              	   C   sz   t d| ddgddddgdd| dgddd	| gg}t| d
d}||}tjtd|j dd || W 5 Q R X d S )Nr   rJ   r   r_   rK   r   r   r   r   r8   r&   zGot 'add_indicator='rP   )	r@   rc   r   rD   rS   ri   rj   rd  rm  )r   rn  r5   ro  r!   r!   r"   3test_simple_imputation_inverse_transform_exceptions  s    



	
 rr  z)expected,array,dtype,extra_value,n_repeatextra_valueZmost_frequent_valuevaluer   Zmin_valuevaluerx   ri  c                 C   s"   | t tj||d||kstd S )Nr^   )r   r@   rc   r   )rf  rc   r   rs  Zn_repeatr!   r!   r"   test_most_frequent  s
      ru  r?   c                 C   sp   t dt jdgdt jt jgg}t| dd}||}t|dddf d ||}t|dddf d dS )zCheck the behaviour of the iterative imputer with different initial strategy
    and keeping empty features (i.e. features containing only missing values).
    rL   r=   rJ   T)r?   keep_empty_featuresNr   )r@   rc   rC   r   rD   r   r,   )r?   r2   r5   rG   r!   r!   r"   *test_iterative_imputer_keep_empty_features  s      

rw  rv  c                 C   s   t dt jdgdt jt jgg}t| d}dD ]`}t|||}| rl|j|jksTtt|dddf d q.|j|jd |jd d fks.tq.dS )z>Check the behaviour of `keep_empty_features` for `KNNImputer`.rL   r=   rJ   )rv  rD   r,   Nr   )r@   rc   rC   r   getattrrF   r   r   )rv  r2   r5   methodrG   r!   r!   r"   $test_knn_imputer_keep_empty_features  s     
r{  c                  C   s  t d} | d| jdd dgddi}t| jddd	}t||tj	dgdgdggt
d | d| jddd
gddi}tddd}t||tj	dgdgd
ggt
d | d| jdd dgddi}t| jddd	}t||tj	dgdgdggdd ttjddd	}t||tj	dgdgdggdd | d| jdd ddgddi}t| jdd}t||tj	dgdgdgdggdd | d| jdd dgddi}t| jdd}t||tj	dgdgdggdd | d| jdd dgddi}t| jddd	}t||tj	dgdgdggdd | d| jdd ddgddi}t| jdd}t||tj	dgdgdgdggdd d S )NrY   featureabcdestringr^   r;   nar   Zfghok)r   r(   rL   rJ   ZInt64r   rT  r=   r9   r&   r8   r   r   rb  g       rc  )rS   rb   re   ZSeriesr   ZNAr#   rD   r@   rc   rd   r$   rC   )rf   r   r5   r!   r!   r"   test_simple_impute_pd_na  s`    
         r  c                  C   sj   t d} tj}| j||d|gd|ddggdddd	gd
}t|d|}| }dddg}t|| dS )zDCheck that missing indicator return the feature names with a prefix.rY   rL   r_   r=   r<   rZ   r[   r\   r]   r`   r   Zmissingindicator_aZmissingindicator_bZmissingindicator_dN)	rS   rb   r@   rC   re   r   r+   Zget_feature_names_outr   )rf   r'   r2   rI  rg   Zexpected_namesr!   r!   r"   (test_missing_indicator_feature_names_outC  s    




r  c                  C   s\   ddgddgddgg} t dd| }|tjtjgg}|jtksHtt|ddgg dS )zkCheck transform uses object dtype when fitted on an object dtype.

    Non-regression test for #19572.
    rZ   r[   r\   r:   r>   N)	r   r+   r,   r@   rC   r   rd   r   r   )r2   Zimp_frequentr6   r!   r!   r"    test_imputer_lists_fit_transformV  s
    r  
dtype_testc                 C   sp   t jddt jgt jddgdddggt jd}t |}t jt jt jt jgg| d}||}|j| ksltdS )	zACheck transform preserves numeric dtype independent of fit dtype.r   g333333@r   g@r=   rL   r^   N)	r@   asarrayrC   rT  r   r+   r,   r   r   )r  r2   impr  r6   r!   r!   r"   .test_imputer_transform_preserves_numeric_dtypec  s     
r  
array_typerc   r   c                 C   s   t t jdgt jdgt jdgg}t|| }d}td||d}dD ]V}t|||}|j|jksdt| dkr~|d	d	d
f jn|d	d	d
f }t	|| qBd	S )zCheck the behaviour of `keep_empty_features` with `strategy='constant'.
    For backward compatibility, a column full of missing values will always be
    fill and never dropped.
    r=   rJ   r   r<   r;   )r(   r   rv  rx  r   Nr   
r@   rc   rC   r   r   ry  rF   r   r  r   )r  rv  r2   r   r5   rz  rG   constant_featurer!   r!   r"   0test_simple_imputer_constant_keep_empty_featuresp  s    "
(r  c                 C   s   t t jdgt jdgt jdgg}t||}t| |d}dD ]~}t|||}|r|j|jksbt|dkr||dddf jn|dddf }t	|d q<|j|jd |jd	 d	 fks<tq<dS )
zYCheck the behaviour of `keep_empty_features` with all strategies but
    'constant'.
    r=   rJ   r   )r(   rv  rx  r   Nr   rL   r  )r(   r  rv  r2   r5   rz  rG   r  r!   r!   r"   'test_simple_imputer_keep_empty_features  s    "
(r  )rS   r%  numpyr@   Zscipyr   Zscipy.statsr   r   Zsklearn.utils._testingr   r   r   r   r	   Zsklearn.experimentalr
   Zsklearn.datasetsr   Zsklearn.imputer   r   r   r   Zsklearn.dummyr   Zsklearn.linear_modelr   r   r   Zsklearn.pipeliner   r   Zsklearn.model_selectionr   Zsklearnr   Zsklearn.random_projectionr   Zsklearn.exceptionsr   Zsklearn.impute._baser   r#   r$   r7   markZparametrizerH   rX   rh   rk   rt   ru   r   r   rd   strr   r   r   r   r   rC   r   r   r   r   rE   r  r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r  r   r*  rc   r,  r0  r3  r4  r7  rA   r}   rB  rJ  rT  Zint32r/   Z
coo_matrixZ
lil_matrixZ
bsr_matrixrQ  rR  rU  rV  rW  rZ  r^  r`  ra  re  rg  rl  rp  rr  r$  ru  rw  r{  r  r  r  Zfloat32r  r  r  r!   r!   r!   r"   <module>   s,  %



C 





"
+	 
% 
!
2
$**


	0


*,






   
	



 
"

9

 


<
