o
    ij                     @   s
  d Z ddlZddlZddlZddlmZ ddlZddlmZ ddl	m
Z
mZmZ ddlmZ ddlmZ eeZeG d	d
 d
ZG dd dZeG dd deZG dd dZG dd deZG dd deZG dd deZG dd deZG dd deeZdS )zJ
Callbacks to use with the Trainer class and customize the training loop.
    N)	dataclass)tqdm   )IntervalStrategySaveStrategy
has_length)TrainingArguments)loggingc                   @   s~  e Zd ZU dZdZedB ed< dZeed< dZ	eed< dZ
eed< dZeed	< dZeed
< dZedB ed< dZeed< dZeed< dZeed< dZeeeef  ed< dZedB ed< dZedB ed< dZedB ed< dZeed< dZeed< dZeed< dZedB ed< dZeeeeB eB eB f dB ed< dZed dB ed< dd Zdefdd Z e!defd!d"Z"d#d$ Z#d%d& Z$dS )'TrainerStatea  
    A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
    and passed to the [`TrainerCallback`].

    <Tip>

    In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
    step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
    step requires going through *n* batches.

    </Tip>

    Args:
        epoch (`float`, *optional*):
            Only set during training, will represent the epoch the training is at (the decimal part being the
            percentage of the current epoch completed).
        global_step (`int`, *optional*, defaults to 0):
            During training, represents the number of update steps completed.
        max_steps (`int`, *optional*, defaults to 0):
            The number of update steps to do during the current training.
        logging_steps (`int`, *optional*, defaults to 500):
            Log every X updates steps
        eval_steps (`int`, *optional*):
            Run an evaluation every X steps.
        save_steps (`int`, *optional*, defaults to 500):
            Save checkpoint every X updates steps.
        train_batch_size (`int`, *optional*):
            The batch size for the training dataloader. Only needed when
            `auto_find_batch_size` has been used.
        num_input_tokens_seen (`int`, *optional*, defaults to 0):
            When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the
            number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
            The total number of floating operations done by the model since the beginning of training (stored as floats
            to avoid overflow).
        log_history (`list[dict[str, float]]`, *optional*):
            The list of logs done since the beginning of training.
        best_metric (`float`, *optional*):
            When tracking the best model, the value of the best metric encountered so far.
        best_global_step (`int`, *optional*):
            When tracking the best model, the step at which the best metric was encountered.
            Used for setting `best_model_checkpoint`.
        best_model_checkpoint (`str`, *optional*):
            When tracking the best model, the value of the name of the checkpoint for the best model encountered so
            far.
        is_local_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
            several machines) main process.
        is_world_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the global main process (when training in a distributed fashion on several
            machines, this is only going to be `True` for one process).
        is_hyper_param_search (`bool`, *optional*, defaults to `False`):
            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
            impact the way data will be logged in TensorBoard.
        stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*):
            Callbacks attached to the `Trainer` that should have their states be saved or restored.
            Relevant callbacks should implement a `state` and `from_state` function.
    Nepochr   global_step	max_stepsi  logging_steps
eval_steps
save_stepstrain_batch_sizenum_train_epochsnum_input_tokens_seen
total_floslog_historybest_metricbest_global_stepbest_model_checkpointTis_local_process_zerois_world_process_zeroFis_hyper_param_search
trial_nametrial_paramsTrainerCallbackstateful_callbacksc                 C   s   | j d u rg | _ | jd u ri | _d S t| jtrd S i }| jD ]6}t|ts/tdt| |jj}||v rOt|| t	sE|| g||< || 
|  q| ||< q|| _d S )NzNAll callbacks passed to be saved must inherit `ExportableState`, but received )r   r   
isinstancedictExportableState	TypeErrortype	__class____name__listappendstate)selfr   callbackname r-   _/sda-disk/www/egybert/egybert_env/lib/python3.10/site-packages/transformers/trainer_callback.py__post_init__t   s&   





zTrainerState.__post_init__	json_pathc                 C   sX   t jt| dddd }t|ddd}|| W d   dS 1 s%w   Y  dS )	zDSave the content of this instance in JSON format inside `json_path`.   T)indent	sort_keys
wutf-8encodingN)jsondumpsdataclassesasdictopenwrite)r*   r0   json_stringfr-   r-   r.   save_to_json   s   "zTrainerState.save_to_jsonc                 C   sH   t |dd}| }W d   n1 sw   Y  | di t|S )z3Create an instance from the content of `json_path`.r6   r7   Nr-   )r=   readr9   loads)clsr0   r@   textr-   r-   r.   load_from_json   s   
zTrainerState.load_from_jsonc                 C   sN   dD ]"}t || d}|dur$|dk rt|| }t| | d| qdS )z
        Calculates and stores the absolute value for logging,
        eval, and save steps based on if it was a proportion
        or not.
        )r	   evalsave_stepsNr   )getattrmathceilsetattr)r*   argsr   	step_kind	num_stepsr-   r-   r.   compute_steps   s   zTrainerState.compute_stepsc                 C   sj   |j dur|jdur| |j| _d| _|dur#ddlm} ||| _|| _|| _| | _|	 | _	dS )zI
        Stores the initial training references needed in `self`
        Nr   )	hp_params)
hp_name_trialr   r   transformers.integrationsrR   r   r   r   r   )r*   trainerr   r   trialrR   r-   r-   r.   init_training_references   s   

z%TrainerState.init_training_references)%r&   
__module____qualname____doc__r   float__annotations__r   intr   r   r   r   r   r   r   r   r   r'   r!   strr   r   r   r   boolr   r   r   r   r   r/   rA   classmethodrF   rQ   rX   r-   r-   r-   r.   r
   "   s8   
 ;$r
   c                   @   s*   e Zd ZdZdefddZedd ZdS )r"   aj  
    A class for objects that include the ability to have its state
    be saved during `Trainer._save_checkpoint` and loaded back in during
    `Trainer._load_from_checkpoint`.

    These must implement a `state` function that gets called during the respective
    Trainer function call. It should only include parameters and attributes needed to
    recreate the state at a particular time, to avoid utilizing pickle/maintain standard
    file IO writing.

    Example:

    ```python
    class EarlyStoppingCallback(TrainerCallback, ExportableState):
        def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
            self.early_stopping_patience = early_stopping_patience
            self.early_stopping_threshold = early_stopping_threshold
            # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
            self.early_stopping_patience_counter = 0

        def state(self) -> dict:
            return {
                "args": {
                    "early_stopping_patience": self.early_stopping_patience,
                    "early_stopping_threshold": self.early_stopping_threshold,
                },
                "attributes": {
                    "early_stopping_patience_counter": self.early_stopping_patience_counter,
                }
            }
    ```returnc                 C   s   t d)Nz<You must implement a `state` function to utilize this class.)NotImplementedErrorr*   r-   r-   r.   r)      s   zExportableState.statec                 C   s8   | di |d }|d   D ]
\}}t||| q|S )NrN   
attributesr-   )itemsrM   )rD   r)   instancekvr-   r-   r.   
from_state   s   zExportableState.from_stateN)r&   rY   rZ   r[   r!   r)   ra   rj   r-   r-   r-   r.   r"      s
     r"   c                   @   st   e Zd ZU dZdZeed< dZeed< dZeed< dZ	eed< dZ
eed< dd	 Zd
d Zdd ZdefddZdS )TrainerControlaA  
    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
    switches in the training loop.

    Args:
        should_training_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the training should be interrupted.

            If `True`, this variable will not be set back to `False`. The training will just stop.
        should_epoch_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the current epoch should be interrupted.

            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
        should_save (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be saved at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_evaluate (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be evaluated at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_log (`bool`, *optional*, defaults to `False`):
            Whether or not the logs should be reported at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
    Fshould_training_stopshould_epoch_stopshould_saveshould_evaluate
should_logc                 C   
   d| _ dS )z<Internal method that resets the variable for a new training.FN)rl   rd   r-   r-   r.   _new_training     
zTrainerControl._new_trainingc                 C   rq   )z9Internal method that resets the variable for a new epoch.FN)rm   rd   r-   r-   r.   
_new_epoch  rs   zTrainerControl._new_epochc                 C   s   d| _ d| _d| _dS )z8Internal method that resets the variable for a new step.FN)rn   ro   rp   rd   r-   r-   r.   	_new_step  s   
zTrainerControl._new_steprb   c                 C   s    | j | j| j| j| jdi dS )Nrl   rm   rn   ro   rp   rN   re   rv   rd   r-   r-   r.   r)     s   zTrainerControl.stateN)r&   rY   rZ   r[   rl   r`   r]   rm   rn   ro   rp   rr   rt   ru   r!   r)   r-   r-   r-   r.   rk      s   
 rk   c                   @   sp  e Zd ZdZdededefddZdededefddZdededefd	d
Z	dededefddZ
dededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefdd Zdededefd!d"Zdededefd#d$Zd%S )&r   a0	  
    A class for objects that will inspect the state of the training loop at some events and take some decisions. At
    each of those events the following arguments are available:

    Args:
        args ([`TrainingArguments`]):
            The training arguments used to instantiate the [`Trainer`].
        state ([`TrainerState`]):
            The current state of the [`Trainer`].
        control ([`TrainerControl`]):
            The object that is returned to the [`Trainer`] and can be used to make some decisions.
        model ([`PreTrainedModel`] or `torch.nn.Module`):
            The model being trained.
        processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
            The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
        optimizer (`torch.optim.Optimizer`):
            The optimizer used for the training steps.
        lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
            The scheduler used for setting the learning rate.
        train_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for training.
        eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for evaluation.
        metrics (`dict[str, float]`):
            The metrics computed by the last evaluation phase.

            Those are only accessible in the event `on_evaluate`.
        logs  (`dict[str, float]`):
            The values to log.

            Those are only accessible in the event `on_log`.

    The `control` object is the only one that can be changed by the callback, in which case the event that changes it
    should return the modified version.

    The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
    You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
    simple [`~transformers.PrinterCallback`].

    Example:

    ```python
    class PrinterCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            _ = logs.pop("total_flos", None)
            if state.is_local_process_zero:
                print(logs)
    ```rN   r)   controlc                 K      dS )zS
        Event called at the end of the initialization of the [`Trainer`].
        Nr-   r*   rN   r)   rx   kwargsr-   r-   r.   on_init_endZ      zTrainerCallback.on_init_endc                 K   ry   )z<
        Event called at the beginning of training.
        Nr-   rz   r-   r-   r.   on_train_begin_  r}   zTrainerCallback.on_train_beginc                 K   ry   )z6
        Event called at the end of training.
        Nr-   rz   r-   r-   r.   on_train_endd  r}   zTrainerCallback.on_train_endc                 K   ry   )z<
        Event called at the beginning of an epoch.
        Nr-   rz   r-   r-   r.   on_epoch_begini  r}   zTrainerCallback.on_epoch_beginc                 K   ry   )z6
        Event called at the end of an epoch.
        Nr-   rz   r-   r-   r.   on_epoch_endn  r}   zTrainerCallback.on_epoch_endc                 K   ry   )z
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr-   rz   r-   r-   r.   on_step_begins  r}   zTrainerCallback.on_step_beginc                 K   ry   )zv
        Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
        Nr-   rz   r-   r-   r.   on_pre_optimizer_stepy  r}   z%TrainerCallback.on_pre_optimizer_stepc                 K   ry   )z}
        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
        Nr-   rz   r-   r-   r.   on_optimizer_step~  r}   z!TrainerCallback.on_optimizer_stepc                 K   ry   )zU
        Event called at the end of an substep during gradient accumulation.
        Nr-   rz   r-   r-   r.   on_substep_end  r}   zTrainerCallback.on_substep_endc                 K   ry   )z
        Event called at the end of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr-   rz   r-   r-   r.   on_step_end  r}   zTrainerCallback.on_step_endc                 K   ry   )z9
        Event called after an evaluation phase.
        Nr-   rz   r-   r-   r.   on_evaluate  r}   zTrainerCallback.on_evaluatec                 K   ry   )z=
        Event called after a successful prediction.
        Nr-   )r*   rN   r)   rx   metricsr{   r-   r-   r.   
on_predict  r}   zTrainerCallback.on_predictc                 K   ry   )z7
        Event called after a checkpoint save.
        Nr-   rz   r-   r-   r.   on_save  r}   zTrainerCallback.on_savec                 K   ry   )z;
        Event called after logging the last logs.
        Nr-   rz   r-   r-   r.   on_log  r}   zTrainerCallback.on_logc                 K   ry   )z7
        Event called after a prediction step.
        Nr-   rz   r-   r-   r.   on_prediction_step  r}   z"TrainerCallback.on_prediction_stepc                 K   ry   )z
        Event called before pushing the model to the hub, at the beginning of Trainer.push_to_hub and Trainer._push_from_checkpoint.
        Nr-   rz   r-   r-   r.   on_push_begin  r}   zTrainerCallback.on_push_beginN)r&   rY   rZ   r[   r   r
   rk   r|   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r-   r-   r-   r.   r   '  s$    1r   c                   @   s  e Zd ZdZdd Zdd Zdd Zdd	 Zed
d Z	de
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefdd Zde
dedefd!d"Zde
dedefd#d$Zde
dedefd%d&Zde
dedefd'd(Zde
dedefd)d*Zde
dedefd+d,Zde
dedefd-d.Zd/d0 Zd1S )2CallbackHandlerz>Internal class that just calls the list of callbacks in order.c                 C   sj   g | _ |D ]}| | q|| _|| _|| _|| _d | _d | _tdd | j D s3t	
d| j  d S d S )Nc                 s   s    | ]}t |tV  qd S N)r    DefaultFlowCallback.0cbr-   r-   r.   	<genexpr>  s    z+CallbackHandler.__init__.<locals>.<genexpr>zThe Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You
should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list ofcallbacks is
:)	callbacksadd_callbackmodelprocessing_class	optimizerlr_schedulertrain_dataloadereval_dataloaderanyloggerwarningcallback_list)r*   r   r   r   r   r   r   r-   r-   r.   __init__  s    zCallbackHandler.__init__c                 C   sh   t |tr| n|}t |tr|n|j}|dd | jD v r,td| dd | j  | j| d S )Nc                 S   s   g | ]}|j qS r-   )r%   )r   cr-   r-   r.   
<listcomp>  s    z0CallbackHandler.add_callback.<locals>.<listcomp>zYou are adding a zH to the callbacks of this Trainer, but there is already one. The currentzlist of callbacks is
:)r    r$   r%   r   r   r   r   r(   )r*   r+   r   cb_classr-   r-   r.   r     s   
zCallbackHandler.add_callbackc                 C   sd   t |tr| jD ]}t ||r| j| |  S qd S | jD ]}||kr/| j| |  S qd S r   r    r$   r   remover*   r+   r   r-   r-   r.   pop_callback  s   



zCallbackHandler.pop_callbackc                 C   sF   t |tr| jD ]}t ||r| j|  d S qd S | j| d S r   r   r   r-   r-   r.   remove_callback  s   


zCallbackHandler.remove_callbackc                 C   s   d dd | jD S )Nr4   c                 s   s    | ]}|j jV  qd S r   )r%   r&   r   r-   r-   r.   r     s    z0CallbackHandler.callback_list.<locals>.<genexpr>)joinr   rd   r-   r-   r.   r     s   zCallbackHandler.callback_listrN   r)   rx   c                 C      |  d|||S )Nr|   
call_eventr*   rN   r)   rx   r-   r-   r.   r|        zCallbackHandler.on_init_endc                 C      d|_ | d|||S )NFr~   )rl   r   r   r-   r-   r.   r~        zCallbackHandler.on_train_beginc                 C   r   )Nr   r   r   r-   r-   r.   r     r   zCallbackHandler.on_train_endc                 C   r   )NFr   )rm   r   r   r-   r-   r.   r     r   zCallbackHandler.on_epoch_beginc                 C   r   )Nr   r   r   r-   r-   r.   r     r   zCallbackHandler.on_epoch_endc                 C   s"   d|_ d|_d|_| d|||S )NFr   )rp   ro   rn   r   r   r-   r-   r.   r     s   zCallbackHandler.on_step_beginc                 C   r   )Nr   r   r   r-   r-   r.   r     r   z%CallbackHandler.on_pre_optimizer_stepc                 C   r   )Nr   r   r   r-   r-   r.   r     r   z!CallbackHandler.on_optimizer_stepc                 C   r   )Nr   r   r   r-   r-   r.   r     r   zCallbackHandler.on_substep_endc                 C   r   )Nr   r   r   r-   r-   r.   r     r   zCallbackHandler.on_step_endc                 C      d|_ | jd||||dS )NFr   r   )ro   r   r*   rN   r)   rx   r   r-   r-   r.   r   
     zCallbackHandler.on_evaluatec                 C   s   | j d||||dS )Nr   r   r   r   r-   r-   r.   r     s   zCallbackHandler.on_predictc                 C   r   )NFr   )rn   r   r   r-   r-   r.   r     r   zCallbackHandler.on_savec                 C   r   )NFr   )logs)rp   r   )r*   rN   r)   rx   r   r-   r-   r.   r     r   zCallbackHandler.on_logc                 C   r   )Nr   r   r   r-   r-   r.   r     r   z"CallbackHandler.on_prediction_stepc                 K   s   | j d|||fi |S )Nr   r   rz   r-   r-   r.   r     s   zCallbackHandler.on_push_beginc              
   K   sP   | j D ]"}t|||||f| j| j| j| j| j| jd|}|d ur%|}q|S )N)r   r   r   r   r   r   )r   rJ   r   r   r   r   r   r   )r*   eventrN   r)   rx   r{   r+   resultr-   r-   r.   r     s&   

zCallbackHandler.call_eventN)r&   rY   rZ   r[   r   r   r   r   propertyr   r   r
   rk   r|   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r-   r-   r-   r.   r     s2    	
r   c                   @   s<   e Zd ZdZdededefddZdededefddZd	S )
r   zx
    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
    rN   r)   rx   c                 K   s   |j dkr|jrd|_|jtjkr|j |j dkrd|_|jtjkr3|j |j dkr3|j	|j kr3d|_
|jtjkrI|jdkrI|j |j dkrId|_|j |jkr[d|_|jtjkr[d|_|S )Nr   Tr   )r   logging_first_steprp   logging_strategyr   STEPSr   eval_strategyr   
eval_delayro   save_strategyr   r   rn   r   rl   rz   r-   r-   r.   r   8  s"   
zDefaultFlowCallback.on_step_endc                 K   sF   |j tjkr	d|_|jtjkr|j|jkrd|_|jt	jkr!d|_
|S )NT)r   r   EPOCHrp   r   r   r   ro   r   r   rn   rz   r-   r-   r.   r   X  s   z DefaultFlowCallback.on_epoch_endN)	r&   rY   rZ   r[   r   r
   rk   r   r   r-   r-   r-   r.   r   3  s     r   c                   @   s\   e Zd ZdZddefddZdd Zdd	 ZdddZdd Z	dd Z
dddZdd Zd
S )ProgressCallbackz
    A [`TrainerCallback`] that displays the progress of training or evaluation.
    You can modify `max_str_len` to control how long strings are truncated when logging.
    d   max_str_lenc                 C   s   d| _ d| _|| _dS )a!  
        Initialize the callback with optional max_str_len parameter to control string truncation length.

        Args:
            max_str_len (`int`):
                Maximum length of strings to display in logs.
                Longer strings will be truncated with a message.
        N)training_barprediction_barr   )r*   r   r-   r-   r.   r   n  s   	
zProgressCallback.__init__c                 K   s    |j rt|jdd| _d| _d S )NT)totaldynamic_ncolsr   )r   r   r   r   current_steprz   r-   r-   r.   r~   {  s   
zProgressCallback.on_train_beginc                 K   s*   |j r| j|j| j  |j| _d S d S r   )r   r   updater   r   rz   r-   r-   r.   r     s   zProgressCallback.on_step_endNc                 K   sJ   |j r!t|r#| jd u rtt|| jd u dd| _| jd d S d S d S )NT)r   leaver   r   )r   r   r   r   lenr   r   )r*   rN   r)   rx   r   r{   r-   r-   r.   r     s   
z#ProgressCallback.on_prediction_stepc                 K   (   |j r| jd ur| j  d | _d S d S r   r   r   closerz   r-   r-   r.   r     
   


zProgressCallback.on_evaluatec                 K   r   r   r   rz   r-   r-   r.   r     r   zProgressCallback.on_predictc           
      K   s   |j rM| jd urOi }| D ].\}}t|tr,t|| jkr,dt| d| j d||< t|tr8|d||< q|||< q|dd }	| j	t| d S d S d S )Nz%[String too long to display, length: z > z/. Consider increasing `max_str_len` if needed.].4gr   )
r   r   rf   r    r_   r   r   r\   popr>   )
r*   rN   r)   rx   r   r{   shallow_logsrh   ri   _r-   r-   r.   r     s   

zProgressCallback.on_logc                 K   s   |j r| j  d | _d S d S r   )r   r   r   rz   r-   r-   r.   r     s   

zProgressCallback.on_train_end)r   r   )r&   rY   rZ   r[   r^   r   r~   r   r   r   r   r   r   r-   r-   r-   r.   r   h  s    

r   c                   @   s   e Zd ZdZdddZdS )PrinterCallbackz?
    A bare [`TrainerCallback`] that just prints the logs.
    Nc                 K   s<   | dd }|jr|d urdd | D }t| d S d S )Nr   c                 S   s(   i | ]\}}|t |tr|d n|qS )r   )r    r\   )r   rh   ri   r-   r-   r.   
<dictcomp>  s   ( z*PrinterCallback.on_log.<locals>.<dictcomp>)r   r   rf   print)r*   rN   r)   rx   r   r{   r   r-   r-   r.   r     s   zPrinterCallback.on_logr   )r&   rY   rZ   r[   r   r-   r-   r-   r.   r     s    r   c                   @   sN   e Zd ZdZddededB fddZd	d
 Zdd Zdd Z	de
fddZdS )EarlyStoppingCallbacka1  
    A [`TrainerCallback`] that handles early stopping.

    Args:
        early_stopping_patience (`int`):
            Use with `metric_for_best_model` to stop training when the specified metric worsens for
            `early_stopping_patience` evaluation calls.
        early_stopping_threshold(`float`, *optional*):
            Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
            specified metric must improve to satisfy early stopping conditions. `

    This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
    in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
    early stopping will not occur until the next save step.
    r           early_stopping_patienceearly_stopping_thresholdNc                 C   s   || _ || _d| _d S )Nr   r   r   early_stopping_patience_counter)r*   r   r   r-   r-   r.   r     s   
zEarlyStoppingCallback.__init__c                 C   sX   |j rtjntj}|jd u s|||jr#t||j | jkr#d| _d S |  jd7  _d S )Nr   r   )greater_is_betternpgreaterlessr   absr   r   )r*   rN   r)   rx   metric_valueoperatorr-   r-   r.   check_metric_value  s   


z(EarlyStoppingCallback.check_metric_valuec                 K   s:   |j std |jd usJ d|jtjksJ dd S )NzUsing EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.zBEarlyStoppingCallback requires metric_for_best_model to be definedzAEarlyStoppingCallback requires IntervalStrategy of steps or epoch)load_best_model_at_endr   r   metric_for_best_modelr   r   NOrz   r-   r-   r.   r~     s   z$EarlyStoppingCallback.on_train_beginc                 K   sl   |j }|dsd| }||}|d u r!td| d d S | |||| | j| jkr4d|_d S d S )Neval_z@early stopping required metric_for_best_model, but did not find z so early stopping is disabledT)	r   
startswithgetr   r   r   r   r   rl   )r*   rN   r)   rx   r   r{   metric_to_checkr   r-   r-   r.   r     s   




z!EarlyStoppingCallback.on_evaluaterb   c                 C   s   | j | jdd| jidS )N)r   r   r   rw   r   rd   r-   r-   r.   r)     s   zEarlyStoppingCallback.state)r   r   )r&   rY   rZ   r[   r^   r\   r   r   r~   r   r!   r)   r-   r-   r-   r.   r     s    r   )r[   r;   r9   rK   r   numpyr   	tqdm.autor   trainer_utilsr   r   r   training_argsr   utilsr	   
get_loggerr&   r   r
   r"   rk   r   r   r   r   r   r   r-   r-   r-   r.   <module>   s2   
 ,=  5J