U
    md                    @   sT  d dl Z d dlZd dlZd dlmZ d dlZd dlmZm	Z	 d dlm
Z
mZ d dlmZ d dlZd dlmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZ d dlmZ d dlmZmZ d dlm Z m!Z! d dl"m#Z# d dlm$Z$ d dl%m&Z& d dlm'Z' d dl"m(Z( d dl)m*Z* d dl+m,Z,m-Z-m.Z/m0Z1 ej23dZ4ee/d dZ.ee1d dZ0dZ5dd gd dgddggZ6e7e6Z8d ddgZ9ddd gZ:e Z;dd Z<dd  Z=d!d" Z>e(d#d$ Z?d%d& Z@ej2Ad'e.eBe;jCd(d)d*e.eBe;jCd+d,d*e.eBe;jCd-d,d*e.eBe;jCd.d/d)d0d1e.eBe;jCd2d/d)d0d1e.eBe;jCd3d)d*gd4d5 ZDej2Ad6e.e0gd7d8 ZEej2Ad9d+d-d.d2gd:d; ZFd<d= ZGd>d? ZHd@dA ZIdBdC ZJdDdE ZKdFdG ZLdHdI ZMdJdK ZNdLdM ZOej2AdNdOdPgfdQdRdSgfdTdRdSgfdUdPgfdVdRdSgfgdWdX ZPdYdZ ZQd[d\ ZRd]d^ ZSd_d` ZTdadb ZUej2Adcdddedfdddedgdhgej2Adidcdjgdkdl ZVdmdn ZWdodp ZXdqdr ZYdsdt ZZdudv Z[dwdx Z\dydz Z]d{d| Z^d}d~ Z_dd Z`ej2Add0gej2Adddgdd Zadd Zbej2Adecddej2Add)d,gej2Adddddddgdd Zdej2Ad9e5dd Zeej2Ad9efege5egd(g ej2Addej2Addej2Add)d,gdd Zhdd Ziej2Add)d,gej2Ad9d(d-d3d2gej2Adddgdd Zjdd Zkdd Zlej2Adddddddddgej2Adddgdd Zmej2Adddddgdd Znej2Adeodddej2Addddgdgdd Zpej2Adddd Zqdd Zrej2Addej2Adddd ZsddÄ Ztddń Zuej2Adeodddej2AddddgdgddǄ ZvddɄ Zwej2jAde.d dd̍e0d ddddd΍gddЄ dэej2Ad9e5ddӄ Zxej2Ad9efege5egd(g ddՄ Zyej2Adddddd؜ddddd؜ddddd؜gddۄ Zzdd݄ Z{ej2Adddgdd߄ Z|ej2Add)d,dgej2Adiddddhdjgdd Z}ej2Ad9e5dd Z~dd Zdd ZdS )    N)partial)assert_allcloseassert_almost_equal)assert_array_almost_equalassert_array_equal)sparse)clone)	load_irismake_classification)log_loss)
get_scorer)StratifiedKFold)GridSearchCV)train_test_split)cross_val_score)LabelEncoderStandardScaler)compute_class_weight	_IS_32BIT)ignore_warnings)shuffle)SGDClassifier)scale)skip_if_no_parallel)ConvergenceWarning)_log_reg_scoring_path_logistic_regression_pathLogisticRegressionLogisticRegressionCVz6error::sklearn.exceptions.ConvergenceWarning:sklearn.*random_state)lbfgs	liblinear	newton-cgnewton-choleskysagsaga      c                 C   s   t |}t|}|jd }| |||}t| j| |j|fksJtt|| | 	|}|j||fksptt
|jddt| t|jdd| dS )z;Check that the model is able to fit the classification datar   r(   ZaxisN)lennpuniqueshapefitpredictr   classes_AssertionErrorpredict_probar   sumonesargmax)clfXy	n_samplesclasses	n_classesZ	predictedprobabilities r>   a/home/sam/Atlas/atlas_env/lib/python3.8/site-packages/sklearn/linear_model/tests/test_logistic.pycheck_predictions4   s    



r@   c                   C   sx   t tddtt t tddtt t tdddtt t tdddtt t tdddtt t tdddtt d S )Nr   r   d   )Cr    F)fit_interceptr    )r@   r   r8   Y1X_spr>   r>   r>   r?   test_predict_2_classesF   s    rF   c                  C   s   G dd d} |  }ddddg}d}t |||d}tdd	\}}||| |jd |d ksbt|j|t| ksxtd|_||||}||j	d kst|jdkstd S )
Nc                   @   s   e Zd Zdd ZdddZdS )z0test_logistic_cv_mock_scorer.<locals>.MockScorerc                 S   s   d| _ ddddg| _d S )Nr   皙?g?皙?      ?)callsscores)selfr>   r>   r?   __init__U   s    z9test_logistic_cv_mock_scorer.<locals>.MockScorer.__init__Nc                 S   s(   | j | jt| j   }|  jd7  _|S )Nr(   )rK   rJ   r+   )rL   modelr8   r9   sample_weightscorer>   r>   r?   __call__Y   s    z9test_logistic_cv_mock_scorer.<locals>.MockScorer.__call__)N)__name__
__module____qualname__rM   rQ   r>   r>   r>   r?   
MockScorerT   s   rU   r(   r)         )Csscoringcvr   r   )
r   r
   r/   C_r2   rJ   r+   rP   r0   rK   )rU   Zmock_scorerrX   rZ   lrr8   r9   Zcustom_scorer>   r>   r?   test_logistic_cv_mock_scorerS   s    
r]   c               	   C   sT   t jj\} }t jt j }tddd}d}tjt|d |	t j| W 5 Q R X d S )Nr"   r)   )solvern_jobsz\'n_jobs' > 1 does not have any effect when 'solver' is set to 'liblinear'. Got 'n_jobs' = 2.match)
irisdatar.   target_namestargetr   pytestwarnsUserWarningr/   )r:   
n_featuresre   r\   warning_messager>   r>   r?   test_lr_liblinear_warningt   s    rk   c                   C   s(   t tddtt t tddtt d S )N
   )rB   )r@   r   r8   Y2rE   r>   r>   r>   r?   test_predict_3_classes   s    rn   r7   r"   ovr)rB   r^   multi_classr!   multinomialr#   r%   {Gz?*   )rB   r^   tolrp   r    r&   r$   c              	   C   s   t jj\}}t jt j }| jdkrRt   tdt	 | 
t j| W 5 Q R X n| 
t j| tt|| j | t j}t||kdkst| t j}t|jddt| t j|jdd }t||kdkstdS )zTest logistic regression with the iris dataset.

    Test that both multinomial and OvR solvers handle multiclass data correctly and
    give good accuracy score (>0.95) for the training data.
    r!   ignoreffffff?r(   r*   N)rb   rc   r.   rd   re   r^   warningscatch_warningssimplefilterr   r/   r   r,   r-   r1   r0   meanr2   r3   r   r4   r5   r6   )r7   r:   ri   re   predr=   r>   r>   r?   test_predict_iris   s    

r|   LRc              
   C   sl  t jt j }}dD ]B}d| d}| |dd}tjt|d ||| W 5 Q R X qdD ]@}d| }| |d	d
d}tjt|d ||| W 5 Q R X qZdD ]@}d| }| |dd
d}tjt|d ||| W 5 Q R X qdD ]@}d|}| |dd}tjt|d ||| W 5 Q R X q| tkrhd}| ddd}tjt|d ||| W 5 Q R X d S )Nr"   r$   zSolver z( does not support a multinomial backend.rq   r^   rp   r`   )r!   r#   r$   r%   z1Solver %s supports only 'l2' or 'none' penalties,l1ro   )r^   penaltyrp   )r!   r#   r$   r%   r&   z1Solver %s supports only dual=False, got dual=TrueT)r^   dualrp   )r"   z>Only 'saga' solver supports elasticnet penalty, got solver={}.
elasticnet)r^   r   z8penalty='none' is not supported for the liblinear solvernoner"   )r   r^   )	rb   rc   re   rf   raises
ValueErrorr/   formatr   )r}   r8   r9   r^   msgr\   r>   r>   r?   test_check_solver_option   s8    
r   r^   c                 C   s   t jdktj}tddg| }t| dddd}|t j| |j	j
dt jj
d fks^t|jj
d	ksntt|t j| t| ddd
d}|t j| |jtj|t jdd }t||kdkstd S )Nr   Zsetosaz
not-setosarq   rs     )r^   rp   r    max_iterr(   r(   F)r^   rp   r    rC   r*   ?)rb   re   astyper,   Zintparrayr   r/   rc   coef_r.   r2   
intercept_r   r0   r1   r6   Zpredict_log_probarz   )r^   re   r7   Zmlrr{   r>   r>   r?   test_multinomial_binary   s*          r   c                 C   s~   t | d\}}tddd| d}||| ||}||}t|t|t|   }tjd| |f }t|| d S )Nr   rq   r&   MbP?)rp   r^   rt   r    r(   )	r
   r   r/   decision_functionr3   r,   expZc_r   )Zglobal_random_seedr8   r9   r7   ZdecisionZprobaZexpected_proba_class_1Zexpected_probar>   r>   r?   %test_multinomial_binary_probabilities   s    

 r   c            
      C   s   t jj\} }t jt j }tt j}tdd||}||}|	  t
|jsVt||}t
|}||}|  ||}	t|| t|| t||	 d S Nr   r   )rb   rc   r.   rd   re   r   r   r/   r   Zsparsifyr   issparser   r2   Z
coo_matrixZdensifyr   )
r:   ri   re   r8   r7   Zpred_d_dZpred_s_dZsp_dataZpred_s_sZpred_d_sr>   r>   r?   test_sparsify  s    







r   c               	   C   s   t jd} | d}t |jd }d|d< tdd}|d d }tt	 |
t| W 5 Q R X tt	 |
||| d W 5 Q R X d S )Nr   )   rl   r   r'   )rV      )r,   randomRandomStateZrandom_sampler5   r.   r   rf   r   r   r/   r8   r0   )rngZX_Zy_r7   Zy_wrongr>   r>   r?   test_inconsistent_input*  s    

r   c                  C   sF   t dd} | tt d| jd d < d| jd d < t| td d S r   )r   r/   r8   rD   r   r   r   r   r7   r>   r>   r?   test_write_parameters>  s
    
r   c               	   C   sJ   t jtt jd} t j| d< tdd}tt |	| t
 W 5 Q R X d S )Ndtyper   r(   r   r   )r,   r   r8   float64nanr   rf   r   r   r/   rD   )ZXnanZlogisticr>   r>   r?   test_nanG  s
    

r   c                  C   sd  t jd} t | ddddg | ddf}dgd dgd  }t ddd}t}dD ]~}|t|||d	d
|dddd	\}}}t|D ]L\}}	t	|	d	d
|dddd}
|

|| |
j }t||| dd| d qq\dD ]~}dg}|t|||d|dddd\}}}t	|d dddd|d}
|

|| t |
j |
jg}t||d dd| d qd S )Nr   rA   r)   r(   r'   rW   rl   r%   r&   Fh㈵>  ro   )rX   rC   rt   r^   r   rp   r    )rB   rC   rt   r^   rp   r    r   zwith solver = %s)decimalerr_msg)r!   r#   r$   r"   r%   r&        @@ư>g     @)rX   rt   r^   intercept_scalingr    rp   )rB   rt   r   r    rp   r^   )r,   r   r   concatenaterandnlogspacer   r   	enumerater   r/   r   ravelr   r   )r   r8   r9   rX   fr^   coefs_irB   r\   Zlr_coefr>   r>   r?   test_consistency_pathR  s~    &	
   

   r   c               
   C   s   t jd} t | ddddg | ddf}dgd dgd  }dg}tt}t|||ddddd W 5 Q R X t	|dkst
|d jjd }d	|kst
d
|kst
d|kst
d|kst
d S )Nr   rA   r)   r(   r'   r           )rX   rt   r   r    verboselbfgs failed to convergez!Increase the number of iterationszscale the dataz%linear_model.html#logistic-regression)r,   r   r   r   r   rf   rg   r   r   r+   r2   messageargs)r   r8   r9   rX   recordZwarn_msgr>   r>   r?   .test_logistic_regression_path_convergence_fail  s(    &      r   c               	   C   s   t ddd\} }tdddddd}|| | tdddddd}|| | td	ddddd}|| | t|j|j d
}tjt|d t|j|j W 5 Q R X d S )N   r   r:   r    Tr   r"   ro   )r    r   rt   r^   rp      z)Arrays are not almost equal to 6 decimalsr`   )r
   r   r/   r   r   rf   r   r2   )r8   r9   Zlr1Zlr2Zlr3r   r>   r>   r?    test_liblinear_dual_random_state  s:    r   c            	      C   s*  d\} }t jd}|| |}t |d|| }|| 8 }||  }tdgddddd	}|	|| t
ddddd
}|	|| t|j|j t|jjd|f t|jddg t|jdkstt t|j }t|jddd|f t|jjd t t|j }t|jd d S )N)2   r   r   r         ?Fr"   ro   rV   )rX   rC   r^   rp   rZ   )rB   rC   r^   rp   r(   r'   r)   r   )r(   rV   r(   )r,   r   r   r   signdotrz   Zstdr   r/   r   r   r   r   r.   r1   r+   r2   asarraylistcoefs_paths_valuesCs_scores_)	r:   ri   r   X_refr9   lr_cvr\   coefs_pathsrK   r>   r>   r?   test_logistic_cv  s<           r   zscoring, multiclass_agg_listZaccuracy 	precisionZ_macroZ	_weightedf1Zneg_log_lossZrecallc                 C   s   t ddddd\}}tdtdd }}tddd	}| }d
D ]
}||= qD||| ||  |D ]L}	t| |	 }
tt||||fdg|
d|d d |
||| ||  qhd S )NrA   r   rV      )r:   r    r<   n_informativeP   r   rq   )rB   rp   )rB   r_   
warm_start)rX   rY   r)   )	r
   r,   aranger   
get_paramsr/   r   r   r   )rY   Zmulticlass_agg_listr8   r9   traintestr\   paramskeyZ	averagingZscorerr>   r>   r?   "test_logistic_cv_multinomial_score  s@       
    r   c            
      C   s  d\} }}t | ||ddd\}}t dddg|}t|d }td	d
}td	dd}td	d
}td	dd}	||| ||| ||| |	|| t|j	|j	 t
|jdddgkstt|j	|	j	 t
|jdddgkstt
|	jdddgkstt
t||dddgks,tt
t|	|dddgksPttddddd	d||}	t
t|	|ddgkstd S )N)r   r   rV   rV   r   )r:   ri   r<   r   r    barbazfoor(   rq   rp   )rp   rX   r)   )r   r   r   )class_weightrp   )r
   r   r/   Zinverse_transformr,   r   r   r   r   r   sortedr1   r2   r-   r0   )
r:   ri   r<   r   r9   Zy_strr\   r   Zlr_strZ	lr_cv_strr>   r>   r?   2test_multinomial_logistic_regression_string_inputs  sB    



$$
  r   c                  C   s|   t dddd\} }d| | dk < t| }t }|| | t }||| t|j|j t|j|j |j|jksxt	d S )Nr   r   r   r:   ri   r    r   r   )
r
   r   
csr_matrixr   r/   r   r   r   r[   r2   )r8   r9   csrr7   Zclfsr>   r>   r?   test_logistic_cv_sparse<  s    
r   c               	   C   st  t jt j } }| j\}}d}t|}t|| |}t|dd}|| | t|dd}|	 }	d|	|	dk< || |	 t
|jd |jd  t
|jdd  |j t
|jd tjd d f |j |jjd|fkstt|jdddg tt|j }
|
jd|d|d fkst|jjdks,ttt|j }|jd|dfksVtd	D ]}|d
krndnd}t|d|d|d
krdnddd}|dkrt| } || | || |}|| |}||kst|jj|jjkstt|jdddg tt|j }
|
jd|d|d fks0t|jjdksBttt|j }|jd|dfksZtqZd S )Nr)   ro   )rZ   rp   r(   r   rV   rl   )rl   r!   r#   r%   r&   r        rq   rs   r   rr   )r^   rp   r   r    rt   rZ   r!   )rb   rc   re   r.   r   r   splitr   r/   copyr   r   r   r   r,   Znewaxisr2   r   r1   r   r   r   r   r   rP   )r   re   r:   ri   Zn_cvrZ   Zprecomputed_foldsr7   clf1Ztarget_copyr   rK   r^   r   	clf_multiZmulti_scoreZ	ovr_scorer>   r>   r?   test_ovr_multinomial_irisJ  sX    
 

r   c                     sl   t dddd\ tdddd fd	d
tD } tj| ddD ]"\}}t| | j| | jdd qDdS )z)Test solvers converge to the same result.rl   r   r   )ri   r   r    Frs   ro   )rC   r    rp   c                    s(   i | ] }|t f d |i qS r^   )r   r/   .0r^   r8   r   r9   r>   r?   
<dictcomp>  s    z4test_logistic_regression_solvers.<locals>.<dictcomp>r)   rrV   r   Nr
   dictSOLVERS	itertoolscombinationsr   r   )
regressorssolver_1solver_2r>   r   r?    test_logistic_regression_solvers  s      r  c                     s   t dddddd\ d} td| dd	d
ddd fddtD }tj|ddD ]"\}}t|| j|| jdd qZdS )zATest solvers converge to the same result for multiclass problems.r   rl   rV   r   r:   ri   r   r<   r    Hz>Frs   ro   )rC   rt   r    rp   r   '  r   c              
      s2   i | ]*}|t f ||d d qS )rA   )r^   r   )r   getr/   r   r8   r   Zsolver_max_iterr9   r>   r?   r     s     
 z?test_logistic_regression_solvers_multiclass.<locals>.<dictcomp>r)   r   rW   r   Nr   )rt   r  r  r  r>   r	  r?   +test_logistic_regression_solvers_multiclass  s&        

  r
  weightrG   g?r   rI   )r   r(   r)   r   balancedc           	   	   C   s   t | }|dkr| }tddddd|dd\}}tddd|d	}tf d
di|}||| tttdg D ]L}tf d
|i|}|dkr|jdddd ||| t|j	|j	dd qndS )z+Test class_weight for LogisticRegressionCV.r  r   rV   r   )r:   ri   
n_repeatedr   n_redundantr<   r    r(   Fro   )rX   rC   rp   r   r^   r!   r   r   r  )rt   r   r    r   rtolN)
r+   r
   r   r   r/   setr   
set_paramsr   r   )	r  r   r<   r8   r9   r   Z	clf_lbfgsr^   r7   r>   r>   r?   (test_logistic_regressioncv_class_weights  s4    
	r  c                  C   s`  t dddddd\} }|d }ttfD ]z}dd	d
d}|tkrP|ddd dD ]b}|f d|i|}|f d|i|}|| | |j| |t|jd d t|j	|j	dd qT|f |}|j| ||d t
tt
d D ]X}|f ||dkrdndd|}	t  |	j| ||d W 5 Q R X t|j	|	j	dd qdD ]`}|f |dddd|}
|
| | |f d|i|}|j| ||d t|
j	|j	dd q@q&tdd	ddddddd
d}|| | tdd	dddd
d}	|	| || t|j	|	j	dd tdd	ddddd dd
d!}|| | tdd	dd dd
d"}	|	| || t|j	|	j	dd d S )#Nr   r   rV   r)   r   r  r(   rs   Fro   )r    rC   rp   )rX   rZ   )r!   r"   r^   rO   -C6?r  )r!   r&   r%   绽|=r   r^   rt   r   )r^   r   r"   r   )r^   rC   r   r   rt   r    rp   )r^   rC   r   rt   r    rp   rW   r   l2T)r^   rC   r   r   r   r    rp   )r^   rC   r   r   r    rp   )r
   r   r   updater/   r,   r5   r.   r   r   r  r   r   r   )r8   r9   rO   r}   kwr^   Zclf_sw_noneZclf_sw_onesZclf_sw_lbfgsZclf_swZ	clf_cw_12Z	clf_sw_12Zclf_cwr>   r>   r?   'test_logistic_regression_sample_weights  s        

 		r  c                 C   s*   t | }td|| d}tt||}|S )Nr  )r;   r9   )r,   r-   r   r   zip)r9   r;   r   class_weight_dictr>   r>   r?    _compute_class_weight_dictionary0  s    
r  c                  C   s  t tj} | dd d d f }tjdd  }d}t|}|D ]J}t|ddd}t|d|d}||| ||| t|j|jdd q<| ddd d f }tjdd }t|}t	t
t	d	 D ]J}t|d
dd}t|d
|d}||| ||| t|j|jdd qd S )N-   )r!   r#   rq   r  )r^   rp   r   rW   r   rA   r   ro   r   )r   rb   rc   re   r  r   r/   r   r   r  r   )ZX_irisr8   r9   Zsolversr  r^   r   Zclf2r>   r>   r?   &test_logistic_regression_class_weights8  sH    
        r   c               	   C   s  d\} }}t | |d|dd\}}tdd|}d}t|dd	}t|ddd
}||| ||| |jj||fkszt|jj||fkstdD ]}t|ddddd}t|dddddd}	||| |	|| |jj||fkst|	jj||fkstt|j|jdd t|j|	jdd t|j	|j	dd qdD ]J}t
|ddddgd}
|
|| t|
j|jdd t|
j	|j	dd q8d S )N)r   r   rV   rl   r   r  F)Z	with_meanr!   rq   r   )r^   rp   rC   )r%   r&   r#   rs   r   r  )r^   rp   r    r   rt   )r^   rp   r    r   rt   rC   rr   r  r   r   r   )r^   r   rt   rp   rX   g{Gz?)r
   r   Zfit_transformr   r/   r   r.   r2   r   r   r   )r:   ri   r<   r8   r9   r^   Zref_iZref_wZclf_iZclf_wZclf_pathr>   r>   r?   $test_logistic_regression_multinomial]  sl    

      r!  c                  C   sP   t dddd\} }tdddd}|| | td} t|| td d S )	Nr   r   r   Fr"   ro   )rC   r^   rp   )r   r   )r
   r   r/   r,   zerosr   r0   r8   r9   r7   r>   r>   r?   %test_liblinear_decision_function_zero  s
    
r$  c                  C   s4   t dddd\} }tddd}|t| | d S )Nrl   r   r   r   r"   ro   r   r
   r   r/   r   r   r#  r>   r>   r?   test_liblinear_logregcv_sparse  s    r&  c                  C   s4   t dddd\} }tddd}|t| | d S )Nrl   r   r   r   r&   rr   r  r%  r#  r>   r>   r?   test_saga_sparse  s    r'  c                  C   s(   t dd} | tt | jdks$td S )NF)rC   r   )r   r/   r8   rD   r   r2   r   r>   r>   r?   "test_logreg_intercept_scaling_zero  s    
r(  c               	   C   s   t jd} d}t|ddd\}}| j|dfd}t j|dfd	}t j|||fd
d}tddddddd}||| tdddddddd}||| t	|j
|j
 t	|j
ddd f t d t	|j
ddd f t d d S )Nrs   r   r   r   r   rV   sizer)   r.   r(   r*   r   r   r"   Fro   r  r   rB   r^   rC   rp   rt   r&   r   r   rB   r^   rC   rp   r   rt   r   )r,   r   r   r
   normalr5   r   r   r/   r   r   r"  )r   r:   r8   r9   X_noise
X_constantlr_liblinearlr_sagar>   r>   r?   test_logreg_l1  s8    	r4  c            	   	   C   s2  t jd} d}t|ddd\}}| jd|dfd}t j|d	fd
}t j|||fdd}d||dk < t|}t	ddddddd}|
|| t	dddddddd}|
|| t|j|j t|jddd f t d t|jddd f t d t	dddddddd}|
| | t|j|j d S )Nrs   r   r   r   r   rG   rV   )r   r*  r)   r+  r(   r*   r   r   r"   Fro   r  r,  r&   r   r-  r.  r   )r,   r   r   r
   r/  r"  r   r   r   r   r/   r   r   Ztoarray)	r   r:   r8   r9   r0  r1  r2  r3  Zlr_saga_denser>   r>   r?   test_logreg_l1_sparse_data  sR    
		r5  random_seedr   r   r  c                 C   sv   t dd| d\}}td|| ddd}tf dgd	d
|}||| tf ddi|}||| t|j|j d S )NrA   r   r   r&   r   -q=)r^   r   r    r   rt   r   T)rX   refitrB   )r
   r   r   r/   r   r   r   )r6  r   r8   r9   Zcommon_paramsr   r\   r>   r>   r?   !test_logistic_regression_cv_refit  s    r9  c                  C   s   t dddddd\} }tddd}|| | t||| }td	dd}|| | t||| }||ksrtt||| }t||| }||kstd S )
Nrl   r   r   rV   )r:   ri   r    r<   r   rq   r!   rp   r^   ro   )r
   r   r/   r   r3   r2   Z_predict_proba_lr)r8   r9   r   Zclf_multi_lossZclf_ovrZclf_ovr_lossZclf_wrong_lossr>   r>   r?   %test_logreg_predict_proba_multinomial8  s"        
r;  r   r   rp   zsolver, message)r#   z@newton-cg failed to converge. Increase the number of iterations.)r"   z@Liblinear failed to converge, increase the number of iterations.)r%   ?The max_iter was reached which means the coef_ did not converge)r&   r<  )r!   r   )r$   z6Newton solver did not converge after [0-9]* iterationsc              	   C   s   t jt j  }}d||dk< |dkr8|dkr8td |dkrR| dkrRtd t| d	|d|d
}tjt|d |	|| W 5 Q R X |j
d | kstd S )Nr   r)   r~   rq   z?'multinomial' is not supported by liblinear and newton-choleskyr$   r(   z/solver newton-cholesky might converge very fastV瞯<)r   rt   rp   r    r^   r`   )rb   rc   re   r   rf   skipr   rg   r   r/   n_iter_r2   )r   rp   r^   r   r8   y_binr\   r>   r>   r?   test_max_iterN  s     

rA  c           	      C   sl  t jt j }}| dkrt|}t|jd }|dks:t| }d||dk< d}d}t	dd| dd	}|
|| |jjd
ksttd| ||dd}|
|| |jjd||fkst|jdd
|| |jj|fkst|jdd
|| |jj|||fkst| dkrd S |jdd
|| |jjd
ks<t|jdd
|| |jjd||fkshtd S )Nr!   r   rV   r)   rW   rr   r   rs   )rt   rB   r^   r    r   )rt   r^   rX   rZ   r    r(   ro   r   r~   rq   )rb   rc   re   r   r,   r-   r.   r2   r   r   r/   r?  r   r  )	r^   r8   r9   r<   r@  Zn_CsZ	n_cv_foldr7   Zclf_cvr>   r>   r?   test_n_iterx  s>        
rB  r   )TFrC   c           
   	   C   s   t jt j }}| dkr"|dkr"d S td||| d|d}ttd* ||| |j}d|_||| W 5 Q R X t	
t	||j }d| |t|t|f }	|rd	|kst|	n|d	kst|	d S )
Nr$   rq   r  rs   )rt   rp   r   r^   r    rC   )categoryr(   zUWarm starting issue with %s solver in %s mode with fit_intercept=%s and warm_start=%s       @)rb   rc   re   r   r   r   r/   r   r   r,   r4   absstrr2   )
r^   r   rC   rp   r8   r9   r7   Zcoef_1Zcum_diffr   r>   r>   r?   test_warm_start  s0    rG  c                  C   s  t  } | j| j }}t|gd }t|gd }||dk }||dk d d }tdddd\}}t|}||f||ffD ]\}}dD ]}|jd }t	d	ddD ]l}	t
d
||	  dddd|ddd}
t
d
||	  dddd|ddd}|
|| ||| t|
j|jd qqqd S )NrV   r(   r)   r   r   r   r   )r   r  r'   r   r&   ro      Fr   )rB   r^   rp   r   rC   r   r    rt   r"   )r	   rc   re   r,   r   r
   r   r   r.   r   r   r/   r   r   )rb   r8   r9   ZX_binr@  ZX_sparseZy_sparser   r:   alphar&   r"   r>   r>   r?   test_saga_vs_liblinear  sN      




rJ  FTc                 C   s  | dkr"|dkr"t d|  d | dkr0tjntj}tttj}tttj}tttj}tttj}t	j
ttjd}t	j
ttjd}	d}
t| |d|
|d	}t|}||| |jj|kstt|}||| |jj|kstt|}||| |jjtjks$tt|}||	| |jjtjksLtd
|
 }tjdkrjtrjd}t|j|jtj|d | dkr|rd}t|j|j|d t|j|j|d d S )Nr~   rq   zSolver=z' does not support multinomial logistic.r"   r   gMb@?rs   )r^   rp   r    rt   rC   gQ@ntrr   atolr&   rG   )rf   r>  r,   r   Zfloat32r   r8   r   rD   r   r   r   r   r/   r   r   r2   osnamer   r   )r^   rp   rC   Z
out32_typeZX_32Zy_32ZX_64Zy_64ZX_sparse_32ZX_sparse_64Z
solver_tolZlr_templZlr_32Zlr_32_sparseZlr_64Zlr_64_sparserM  r>   r>   r?   test_dtype_match  sJ    		rP  c                  C   s   t jd} t | ddddg | ddf}t dgd dgd  }tddddd	}tddd
dd	}t||||	|}t
dD ]}||| qt||	|}t||dd d S )Nr   rA   r)   r(   r'   rq   r%   F)rp   r^   r   r    Tr   r   r  )r,   r   r   r   r   r   r   r   r/   r3   ranger   )r   r8   r9   Zlr_no_wsZlr_wsZlr_no_ws_lossr   Z
lr_ws_lossr>   r>   r?   test_warm_start_converge_LRN  s(    &      rR  c            
   
   C   s   t dd\} }d}d}t }dD ]2}t||dd|ddd	}|| | ||j q |\}}}	tj||dd
drtttj||	dd
drttj|	|dd
drtd S )Nr   r   rD  rI   )r   r   r  r&   r   rH  )r   rB   r^   r    l1_ratiort   r   rG   )r  rM  )	r
   r   r   r/   appendr   r,   allcloser2   )
r8   r9   rB   rS  Zcoeffsr   r\   Zelastic_net_coeffsZ	l1_coeffsZ	l2_coeffsr>   r>   r?   test_elastic_net_coeffsc  s(    	
rV  rB   r   rl   rA   r   g    .Azpenalty, l1_ratio)r   r(   )r  r   c                 C   s^   t dd\}}td| |dddd}t|| dddd}||| ||| t|j|j d S )Nr   r   r   r&   rr   )r   rB   rS  r^   r    rt   r   rB   r^   r    rt   )r
   r   r/   r   r   )rB   r   rS  r8   r9   lr_enetZlr_expectedr>   r>   r?   "test_elastic_net_l1_l2_equivalence  s&        rY  c                 C   s   t ddd\}}t||dd\}}}}dtdddi}td| ddd	d
}t||dd}	td| ddd	d
}
td| ddd	d
}|	|
|fD ]}||| q|	|||
||kst|	|||||kstd S )Nr   r   r   rS  r(   r   r   r&   rr   rW  T)r8  r   r  )	r
   r   r,   linspacer   r   r/   rP   r2   )rB   r8   r9   X_trainX_testy_trainy_test
param_gridZenet_clfgsZl1_clfZl2_clfr7   r>   r>   r?   test_elastic_net_vs_l1_l2  s:                ra  rW   rS  r   c              	      s   t dddddddd\ttddd d	d
}tddd d	d}| |  fdd}||||k std S )Nr   r)   r   rl   r   r:   r<   ri   r   r  r  r    r   r&   F)r   r^   r    rB   rS  rC   r  )r   r^   r    rB   rC   c                    sV   | j  } t|  }|tt| 7 }|d d t|| 7 }|S )Nr   rI   )r   r   r   r3   r,   r4   rE  r   )r\   ZcoefobjrB   r8   rS  r9   r>   r?   enet_objective  s
    
zEtest_LogisticRegression_elastic_net_objective.<locals>.enet_objective)r
   r   r   r/   r2   )rB   rS  rX  Zlr_l2rf  r>   re  r?   -test_LogisticRegression_elastic_net_objective  s:    
	    rg  )ro   rq   c           
   
   C   s   | dkrt dd\}}nt ddddd\}}td}tddd}td	d
d}td|d||d| dd}||| ||d}tddd| dd}t|||d}	|	|| |	j	d |j
d kst|	j	d |jd kstd S )Nro   r   r   rA   rV   r:   r<   r   r    r   r(   rW   r   r&   rr   r   rX   r^   rZ   	l1_ratiosr    rp   rt   rB   rS  r   r^   r    rp   rt   rZ   rS  rB   )r
   r   r,   rZ  r   r   r/   r   r   Zbest_params_	l1_ratio_r2   r[   )
rp   r8   r9   rZ   rk  rX   lrcvr_  r\   r`  r>   r>   r?   2test_LogisticRegressionCV_GridSearchCV_elastic_net  sD       


rq  c               
   C   s   t ddddd\} }t| |dd\}}}}td}tddd}tdd	d}td
|d||dddd}	|	|| ||d}
td
ddddd}t	||
|d}||| |	
||
|k dkst|	
||
|k dkstd S )NrA   rV   r   rh  r   r   r(   ri  rW   r   r&   ro   rr   rj  rl  rm  rn  rH   )r
   r   r   r,   rZ  r   r   r/   r   r   r0   rz   r2   )r8   r9   r[  r\  r]  r^  rZ   rk  rX   rp  r_  r\   r`  r>   r>   r?   6test_LogisticRegressionCV_GridSearchCV_elastic_net_ovr  sB       


 rr  )r  r   )ro   rq   autoc           	   
   C   s   d}d}t d|||dd\}}tddd}| dkrDtdd	d
}nd }t| |d|d|ddd}||| |jj|fks~t|j	j|fkst|j
j||fkstd S )NrV   r   rH  r   r:   r<   r   ri   r    ri  rW   r   r(   r)   r&   rr   F)r   rX   r^   rk  r    rp   rt   r8  )r
   r,   r   rZ  r   r/   r[   r.   r2   ro  r   )	r   rp   r<   ri   r8   r9   rX   rk  rp  r>   r>   r?   "test_LogisticRegressionCV_no_refit:  s6    

ru  c            
   
   C   s   d} d}t d| | |dd\}}tddd}tddd	}d	}td
|d||dddd}||| tt|j	 }|j
| ||j|j|d fksttt|j	 }	|	j
| ||j|jfkst|jj
| ||j|jfkstd S )NrV   r   rH  r   rt  ri  rW   r(   r)   r   r&   ro   rr   )r   rX   r^   rZ   rk  rp   r    rt   )r
   r,   r   rZ  r   r/   r   r   r   r   r.   r*  r2   r   r?  )
r<   ri   r8   r9   rX   rk  Zn_foldsrp  r   rK   r>   r>   r?   5test_LogisticRegressionCV_elasticnet_attribute_shapes_  sD    


rv  c               	   C   s8   d} t jt| d tddddtt W 5 Q R X d S )NzQl1_ratio parameter is only used when penalty is 'elasticnet'\. Got \(penalty=l1\)r`   r   r&   rI   )r   r^   rS  )rf   rg   rh   r   r/   r8   rD   )r   r>   r>   r?   test_l1_ratio_non_elasticnet  s    rw  c              
   C   s   d}t |ddddddd\}}t|}tdddd d	|d
|  | dd}tddddd|| dd}||| ||| t|j|jdd d S )Nr   r)   r   r   r(   rc  r   Fr   r   r   )r   r    rC   rt   r   rS  rI  Zlossr   r   r&   )r   r    rC   rt   r   rS  rB   r^   r   )r
   r   r   r   r/   r   r   )rB   rS  r:   r8   r9   Zsgdlogr>   r>   r?   test_elastic_net_versus_sgd  sD    
	

ry  c               	   C   s   t dddddddd\} }dddg}t| |d	|d
ddd\}}}tt t|d |d dd W 5 Q R X tt t|d |d dd W 5 Q R X tt t|d |d dd W 5 Q R X d S )NrH  rV   r)   r   r(   )r:   r<   r   r  Zn_clusters_per_classr    ri   r   r  r   r&   rq   )r   rX   r^   r    rp   r   )r
   r   rf   r   r2   r   )r8   r9   rX   r   r   r>   r>   r?   /test_logistic_regression_path_coefs_multinomial  s2    
	

  rz  estr   )r    r   rV   )r    rZ   rX   rt   r   c                 C   s   | j jS N)	__class__rR   )xr>   r>   r?   <lambda>      r  )Zidsc              	      sX   fdd}t tj}|d d d }|dd d }tjd d d }|dk}|||d|d}|||d|d}	t|j|	j t|||	| |||d|d}
|d	kr|||d|d}t|
j|j t|
||| nx|||d
|d}t|
j|j t|
||| t|j|||d
|djr2t	t|j|||d
|djrTt	d S )Nc                    s   t  jf || |S r|  )r   r  r/   )r8   r9   r  r{  r>   r?   r/     s    z6test_logistic_regression_multi_class_auto.<locals>.fitrl   r(   r   rs  r:  ro   r~   rq   )
r   rb   rc   re   r   r   r3   r,   rU  r2   )r{  r^   r/   Zscaled_datar8   X2Zy_multir@  Zest_auto_binZest_ovr_binZest_auto_multiZest_ovr_multiZest_multi_multir>   r  r?   )test_logistic_regression_multi_class_auto  s@    
  
r  c           	   	   C   s   t ddd\}}d}td | dd}tjt|d ||| W 5 Q R X td | dd}td	tj| dd
}||||}||||}t	|| d S )Nr   r   r   z&Setting penalty=None will ignore the CrW   )r   r^   rB   r`   )r   r^   r    r  )r   rB   r^   r    )
r
   r   rf   rg   rh   r/   r,   infr0   r   )	r^   r8   r9   r   r\   Zlr_noneZlr_l2_C_infZ	pred_noneZpred_l2_C_infr>   r>   r?   test_penalty_none  s       r  r   r   )r   r   rt   r   r7  c                 C   st  t jddgddgddgddgddgddgddgddgddgddgddgddgddgddgddgddggt dd}t jddddddddddddddddgt dd}t ||g}t |d| g}t jt|d d}d	|t|d < t|||d	d
\}}}tddd}|j	f |  t
|||}t
|j|||d}dD ],}	t||	|}
t||	|}t|
| qBd S )Nr(   rV   r)   rW   floatr   intr+  r   r   r"   rs   )r^   r    r  )r0   r3   r   )r,   r   r   ZvstackZhstackr5   r+   r   r   r  r   r/   getattrr   )r   r8   r9   r  y2rO   Zbase_clfZclf_no_weightZclf_with_weightmethodZX_clf_no_weightZX_clf_with_weightr>   r>   r?   /test_logisticregression_liblinear_sample_weight&  sJ    " r  c                  C   s   t ddd\} }tdd}ddg}ddd	g}td
d|||dddd}|| | |jd jdd}t|D ]^\}}t|D ]L\}	}
td
d||
dddd}t|| ||d }|||	f t	
|ksztqzqjd S )Nr   r   r   r   )Zn_splitsrG   r   r(   rl   r   r&      r   )r   r^   rk  rX   rZ   r    r   rt   r*   )r   r^   rB   rS  r    r   rt   rn  )r
   r   r   r/   r   rz   r   r   r   rf   approxr2   )r8   r9   rZ   rk  rX   rp  Zavg_scores_lrcvr   rB   jrS  r\   Zavg_score_lrr>   r>   r?   'test_scores_attribute_layout_elasticnet[  s:    



r  c                 C   s   t jj\}}t jt j }ttt jdd| d}tt j}||| t	|j
jddddd | r||jjddtjddd	k d
S )a|  Test that the multinomial classification is identifiable.

    A multinomial with c classes can be modeled with
    probability_k = exp(X@coef_k) / sum(exp(X@coef_l), l=1..c) for k=1..c.
    This is not identifiable, unless one chooses a further constraint.
    According to [1], the maximum of the L2 penalized likelihood automatically
    satisfies the symmetric constraint:
    sum(coef_k, k=1..c) = 0

    Further details can be found in [2].

    Reference
    ---------
    .. [1] :doi:`Zhu, Ji and Trevor J. Hastie. "Classification of gene microarrays by
           penalized logistic regression". Biostatistics 5 3 (2004): 427-43.
           <10.1093/biostatistics/kxg046>`

    .. [2] :arxiv:`Noah Simon and Jerome Friedman and Trevor Hastie. (2013)
           "A Blockwise Descent Algorithm for Group-penalized Multiresponse and
           Multinomial Regression". <1311.6529>`
    r!   rq   )rB   r^   rp   rC   r   r*   r  rL  r=  )rE  N)rb   rc   r.   rd   re   r   r+   r   r/   r   r   r4   r   rf   r  )rC   r:   ri   re   r7   ZX_scaledr>   r>   r?   (test_multinomial_identifiability_on_iris  s    
r  rs  r   g      $@c                 C   sf   t dd\}}t|}t|}d|d |d < | }td|d| d}|j|||d t|| d S )NT)Z
return_X_yr)   r   rH  )r    r   r   rp   r  )r	   r+   r,   r5   r   r   r/   r   )rp   r   r8   r9   ri   Wexpectedr7   r>   r>   r?   test_sample_weight_not_modified  s    
   r  c              	   C   s   t jdddd}dD ]}t||t||d qtjjd|jd d	}| d
krd}t	j
t|d t| d|| W 5 Q R X nt| d|| d S )Nr   rl   r   )r   )indicesZindptrZint64r)   r   r)  )r"   r%   r&   z0Only sparse matrices with 32-bit integer indicesr`   r   )r   Zrandsetattrr  r   r,   r   randintr.   rf   r   r   r   r/   )r^   r8   attrr9   r   r>   r>   r?   test_large_sparse_matrix  s    r  c               
   C   sb   t ddddddddggj} t d	d	d
d
d	d	d
d	g}| jd	 d	ksJttddd| | d S )NrI   g?g?g      ?rH   gHzG?rv   gffffff?r(   r   r#   T)r^   rC   )r,   r   Tr.   r2   r   r/   )r8   r9   r>   r>   r?   test_single_feature_newton_cg  s    r  c               	   C   sF   t jt j } tdd}d}tjt|d |t j|  W 5 Q R X d S )Nr   )r   zv`penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`.r`   )	rb   rd   re   r   rf   rg   FutureWarningr/   rc   )re   r\   rj   r>   r>   r?   #test_warning_on_penalty_string_none  s    
r  )r   rN  rw   	functoolsr   numpyr,   Znumpy.testingr   r   r   r   Zscipyr   rf   Zsklearn.baser   Zsklearn.datasetsr	   r
   Zsklearn.metricsr   r   Zsklearn.model_selectionr   r   r   r   Zsklearn.preprocessingr   r   Zsklearn.utilsr   r   Zsklearn.utils._testingr   r   Zsklearn.linear_modelr   r   r   Zsklearn.exceptionsr   Zsklearn.linear_model._logisticr   r   r   ZLogisticRegressionDefaultr   ZLogisticRegressionCVDefaultmarkfilterwarningsZ
pytestmarkr   r8   r   rE   rD   rm   rb   r@   rF   r]   rk   rn   Zparametrizer+   rc   r|   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r
  r  r  r  r   r!  r$  r&  r'  r(  r4  r5  r9  r;  r   rA  rB  r   r  rG  rJ  rP  rR  rV  rY  ra  r   rg  rq  rr  ru  rv  rw  ry  rz  r  r  r  r  r  r  r  r  r  r>   r>   r>   r?   <module>   s  


!
        

*
	B$



)C W%?&5
1$/ 
H
)
--#+	( 
*

-,
*
