o
    0i">                     @   s~   d dl Z d dlZd dlmZ d dlmZmZmZmZ d dl	m
Z
mZmZmZ d dlmZ er4d dlmZ eG dd dZdS )	    N)	dataclass)TYPE_CHECKINGLiteralOptionalUnion)DeepSpeedSequenceParallelConfigDistributedTypeTorchContextParallelConfigTorchTensorParallelConfig)is_torch_version)Acceleratorc                   @   s  e Zd ZU dZdZee ed< dZee ed< dZ	ee ed< dZ
ee ed< dZed ed< dZee ed	< dZed
 ed< dZedef ed< dZedef ed< dZedef ed< dZdd Zdd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Z ed!d" Z!ed#d$ Z"ed%d& Z#ed'd( Z$ed)d* Z%ed+d, Z&ed-d. Z'd/e(fd0d1Z)dBd/ee( fd2d3Z*d4e+e+ed5f e+e(d5f f fd6d7Z,d8d9 Z-d:e(d;efd<d=Z.dCd@dAZ/dS )DParallelismConfiga  
    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py

    Args:
        dp_replicate_size (`int`, defaults to `1`):
            The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
            group will not be used.
        dp_shard_size (`int`, defaults to `1`):
            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
            be greater than 1, as composing DDP + TP is currently not supported.
        tp_size (`int`, defaults to `1`):
            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
            used.
        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
            The handler for the tensor parallel group.
        cp_size (`int`, defaults to `1`):
            The size of the context parallel group. Currently not supported, but reserved for future use and enabled
            for downstream libraries.
        cp_backend (`str`, defaults to `torch`):
            Which CP backend to use: `torch` (FSDP2)
        sp_size (`int`, defaults to `1`):
            The size of the sequence parallel group.
        sp_backend (`str`, defaults to `deepspeed`):
            Which SP backend to use:`deepspeed` (ALST/Ulysses)

    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
    together:
        - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
        - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
        - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration, to use pure DP, use
          `DistributedDataParallelKwargs` instead.

    Ndp_replicate_sizedp_shard_sizetp_sizecp_sizetorch
cp_backendsp_size	deepspeed
sp_backend
tp_handler
cp_handler
sp_handlerc                 C   sV   d| j  d| j d| j d| j d| j d| j d| j d| j d	| j d
| j	 dS )Nz'ParallelismConfig(
 	dp_replicate_size=z,
	dp_shard_size=z,
	tp_size=z,
	cp_size=z,
	cp_backend=z,
	sp_size=z,
	sp_backend=z,
	total_size=z
	tp_handler=z,
	cp_handler=z)
)
r   r   r   r   r   r   r   
total_sizer   r   self r   _/sda-disk/www/egybert/egybert_env/lib/python3.10/site-packages/accelerate/parallelism_config.py__repr__U   s,   	
zParallelismConfig.__repr__c                    s2   dd l dg  fdd| j D  d S )Nr   device_meshc                    s4   i | ]\}}| vr|t |d r|jn|qS )__dict__)hasattrdeepcopyr!   ).0kv_non_serializable_fieldscopyr   r   
<dictcomp>j   s
    z-ParallelismConfig.to_json.<locals>.<dictcomp>)r)   r#   r!   itemsr   r   r'   r   to_jsond   s   zParallelismConfig.to_jsonc                 C   (   g }| j r
|dg7 }| jr|dg7 }|S )zENames of enabled dimensions across which data parallelism is applied.dp_replicatedp_shard)dp_replicate_enableddp_shard_enabledr   dimsr   r   r   dp_dim_namesq      

zParallelismConfig.dp_dim_namesc                 C   8   g }| j r
|dg7 }| jr|dg7 }| jr|dg7 }|S )z]Names of enabled dimensions which will receive the same batch (non-data parallel dimensions).tpcpsp)
tp_enabled
cp_enabled
sp_enabledr2   r   r   r   non_dp_dim_names{      


z"ParallelismConfig.non_dp_dim_namesc                 C   r-   )zlNames of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP.r/   r8   )r1   r;   r2   r   r   r   dp_shard_cp_dim_names   r5   z'ParallelismConfig.dp_shard_cp_dim_namesc                 C   r6   )z@Names of enabled dimensions across which loss should be averagedr.   r/   r8   )r0   r1   r;   r2   r   r   r   dp_cp_dim_names   r>   z!ParallelismConfig.dp_cp_dim_namesc                 C   s"   g }| j r
|dg7 }|dg7 }|S )z^Names of enabled dimensions across which FSDP is applied, including data parallel replication.r.   dp_shard_cp)r0   r2   r   r   r   fsdp_dim_names   s
   

z ParallelismConfig.fsdp_dim_namesc                 C   s   | j | j | j | j | j S )zSThe total size of the parallelism configuration, which is the product of all sizes.)r   r   r   r   r   r   r   r   r   r      s   zParallelismConfig.total_sizec                 C   s   | j | j | j S )zhThe size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes.)r   r   r   r   r   r   r   non_data_parallel_size   s   z(ParallelismConfig.non_data_parallel_sizec                 C   s   | j | j S )z_The size of the data parallel dimensions, which is the product of data parallel replication and)r   r   r   r   r   r   data_parallel_size      z$ParallelismConfig.data_parallel_sizec                 C   
   | j dkS )zKTrue if data parallel replication is enabled, i.e. `dp_replicate_size > 1`.   )r   r   r   r   r   r0         
z&ParallelismConfig.dp_replicate_enabledc                 C   rF   )zDTrue if data parallel sharding is enabled, i.e. `dp_shard_size > 1`.rG   )r   r   r   r   r   r1      rH   z"ParallelismConfig.dp_shard_enabledc                 C   rF   )z:True if tensor parallelism is enabled, i.e. `tp_size > 1`.rG   )r   r   r   r   r   r:      rH   zParallelismConfig.tp_enabledc                 C   rF   )z;True if context parallelism is enabled, i.e. `cp_size > 1`.rG   )r   r   r   r   r   r;      rH   zParallelismConfig.cp_enabledc                 C   rF   )z;True if context parallelism is enabled, i.e. `sp_size > 1`.rG   )r   r   r   r   r   r<      rH   zParallelismConfig.sp_enabledc                 C   s   | j | j S )z$Names of all active mesh dimensions.)r4   r=   r   r   r   r   active_mesh_dims   rE   z"ParallelismConfig.active_mesh_dimsdevice_typec                 C   s   t ddrddlm} ntd|  }t|dkrdS |\}}||||d}| jr2|| j d | jr=|| j d	 | j	rH|| j	 d
 |S )a!  Builds a device mesh for the given device type based on the parallelism configuration.
        This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).

        Args:
            device_type (`str`): The type of device for which to build the mesh, e
        z>=z2.2.0r   )init_device_meshz4Building a device_mesh requires to have torch>=2.2.0N)mesh_dim_namesdprA   dp_cp)
r   torch.distributed.device_meshrK   RuntimeError	_get_meshlenr4   _flattenr?   r@   )r   rJ   rK   meshrL   
mesh_shaper    r   r   r   build_device_mesh   s&   
z#ParallelismConfig.build_device_meshc                 C   s\   | j d u r|d ur| || _ | j S d|d ur+| j j|kr+td| j j d| d| j S )Nz@You need to pass a device_type e.g cuda to build the device meshz4The device_mesh is already created with device type z@. However, you are trying to get a device mesh with device_type z<. Please check if you correctly initialized your device_mesh)r    rV   rJ   
ValueError)r   rJ   r   r   r   get_device_mesh   s   
	z!ParallelismConfig.get_device_meshreturn.c                    s@   fddj D }g d t|  fddd}tt| S )zQGenerate mesh shape and dimension names for torch.distributed.init_device_mesh().c                    s   i | ]}| j | qS r   )_sizes)r$   parallelismr   r   r   r*     s    z/ParallelismConfig._get_mesh.<locals>.<dictcomp>)r.   r/   r8   r9   r7   c                    s     | d S )Nr   )index)x)
mesh_orderr   r   <lambda>	  s    z-ParallelismConfig._get_mesh.<locals>.<lambda>)key)rI   sortedr+   tuplezip)r   	mesh_dimssorted_itemsr   )r^   r   r   rQ      s   
zParallelismConfig._get_meshc                 C   s  | j d u rttjdd| _ | jd u rttjdd| _| jd u r-ttjdd| _| jd u r<ttjdd| _| jd u rItjdd| _| j	d u rXttjdd| _	| j
d u retjd	d
| _
| jdkrs| jd u rst | _| jdkr| jd u rt | _n"ttd}t| j|| j std| j d|| j  dt| j | j	dkr| jd u rt | _| j dk rtd| j  | jdk rtd| j | jdk rtd| j | jdk rtd| j dg}| j|vrtd| d| j | j	dk rtd| j	 d
g}| j
|vrtd| d| j
 | jdks'| jdkr7| j dkr7| jdkr7td| j | j| j| j| j	d| _d S )N$PARALLELISM_CONFIG_DP_REPLICATE_SIZE1 PARALLELISM_CONFIG_DP_SHARD_SIZEPARALLELISM_CONFIG_TP_SIZEPARALLELISM_CONFIG_CP_SIZEPARALLELISM_CONFIG_CP_BACKENDr   PARALLELISM_CONFIG_SP_SIZEPARALLELISM_CONFIG_SP_BACKENDr   rG   )r   zParallelismConfig's cp_backend=z
 requires z, but cp_handler was set to z.dp_replicate_size must be at least 1, but got z*dp_shard_size must be at least 1, but got z$tp_size must be at least 1, but got z$cp_size must be at least 1, but got zcp_backend must be one of z
, but got z$sp_size must be at least 1, but got zsp_backend must be one of aC  Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel.)r.   r/   r7   r8   r9   )r   intosenvirongetr   r   r   r   r   r   r   r
   r   r	   dict
isinstancerW   typer   r   rZ   )r   cp_backends_config_mapvalid_cp_backendsvalid_sp_backendsr   r   r   __post_init__  sp   











"






0zParallelismConfig.__post_init__r[   sizec                 C   sB   || j  v sJ d| j   || j |< t| | d| d S )NzParallelism must be one of _size)rZ   keyssetattr)r   r[   ry   r   r   r   	_set_sizeQ  s   "
zParallelismConfig._set_sizeacceleratorr   c                 C   s  t  }|js| jdkrd S | jdkr| d|j | j|jkr,td| j d|j d| jdkrF|jsF|jsF|jtj	ksFtd|j d| j
 D ]\}}|dkrjt| | dd d urj|d	| d
| d qK|r}|jrtdd| t d S d S d S )NrG   r.   zParallelismConfig total_size (z ) does not match num_processes (zJ). Please adjust dp_replicate_size/ dp_shard_size/tp_size/cp_size/sp_size.zParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{Device} or DistributedType.DEEPSPEED, but got ._handlerzParallelismConfig.z_handler is set, but z0_size is set to 1. This handler will be ignored.z.ParallelismConfig has the following warnings:

)setmulti_devicer   r}   num_processesrW   is_fsdp2distributed_typer   	DEEPSPEEDrZ   r+   getattraddis_main_processwarningswarnjoinUserWarning)r   r~   	_warningsr[   ry   r   r   r   _validate_acceleratorV  s@   



z'ParallelismConfig._validate_accelerator)N)r~   r   )0__name__
__module____qualname____doc__r   r   rn   __annotations__r   r   r   r   r   r   r   r   r   r
   r   r	   r   r   r    r   r,   propertyr4   r=   r?   r@   rB   r   rC   rD   r0   r1   r:   r;   r<   rI   strrV   rX   rb   rQ   rx   r}   r   r   r   r   r   r   !   sb   
 #
	

	










&Dr   )ro   r   dataclassesr   typingr   r   r   r   accelerate.utils.dataclassesr   r   r	   r
   accelerate.utils.versionsr   
accelerater   r   r   r   r   r   <module>   s   