"""Base models for the core app of the Marsha project."""
from datetime import date, datetime
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type, get_type_hints
from django.core import checks
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from psqlextra.indexes import ConditionalUniqueIndex
from safedelete.models import SOFT_DELETE_CASCADE, SafeDeleteModel
from marsha.stubs import M2MType, ReverseFKType, TupleOfStr, Typing
# Alias used as the return type of every ``_check_*`` method defined on ``BaseModel``.
CheckMessages = List[checks.CheckMessage] # pylint: disable=invalid-name
# Labels of the django apps whose fields/relations are subject to the custom checks;
# fields and relations whose defining model has another ``app_label`` are skipped.
CHECKED_APPS = {"core"}
# Python type expected as annotation for each (non relational) django field class.
# It is read by ``BaseModel._get_expected_field_annotation``.
fields_type_mapping: Mapping[Type[models.Field], type] = {
models.AutoField: int,
models.BooleanField: bool,
models.CharField: str,
models.DateField: date,
models.DateTimeField: datetime,
models.FloatField: float,
models.IntegerField: int,
models.TextField: str,
}
# Generic stub type expected as annotation of a "multiple" reverse relation, keyed by
# the class of the forward field. It is read by
# ``BaseModel._check_annotated_related_names`` and subscripted with the related model.
reverse_fields_type_mapping: Mapping[Type[RelatedField], type] = {
models.ForeignKey: ReverseFKType,
models.ManyToManyField: M2MType,
}
def _get_fields_by_source_model(
    model: Type[models.Model]
) -> Dict[str, Type[models.Model]]:
    """Map every field name of a model to the model class that declares it.

    The base classes of ``model`` are explored recursively first, so a field
    inherited from a parent model is attributed to that parent (or to one of
    its own ancestors), and only the names not seen in any base are attributed
    to ``model`` itself.

    Parameters
    ----------
    model: Type[models.Model]
        The model for which to get the fields origin.

    Returns
    -------
    Dict[str, Type[models.Model]]
        A dict with keys being the fields names, and values the model on which the
        field is defined.

    """
    origins: Dict[str, Type[models.Model]] = {}

    # Only real django models are worth exploring: ``models.Model`` itself and
    # the bases that are not django models at all are left aside.
    model_parents = (
        parent
        for parent in model.__bases__
        if parent is not models.Model and issubclass(parent, models.Model)
    )
    for parent in model_parents:
        origins.update(_get_fields_by_source_model(parent))

    # What has not been claimed by a base is declared on the model itself.
    for model_field in model._meta.get_fields():
        origins.setdefault(model_field.name, model)

    return origins
class NonDeletedUniqueIndex(ConditionalUniqueIndex):
    """A special ConditionalUniqueIndex for non deleted objects.

    The SQL condition of the index is fixed by the ``condition`` class attribute,
    so the uniqueness only applies to rows whose ``deleted`` column is null.
    """

    # SQL condition passed to ``ConditionalUniqueIndex`` for every instance.
    condition: str = '"deleted" IS NULL'

    def __init__(self, fields: Sequence, name: str = None) -> None:
        """Override default init to pass our predefined condition.

        For the parameters, see ``ConditionalUniqueIndex.__init__``.

        """
        super().__init__(condition=self.condition, fields=fields, name=name)

    def deconstruct(self):  # type: ignore
        """Remove ``condition`` as an argument to be defined in migrations.

        Returns
        -------
        tuple
            The ``(path, args, kwargs)`` tuple produced by the parent class, with
            the ``condition`` entry removed from ``kwargs``: ``__init__`` does not
            accept it and always passes the class level one.

        """
        path, args, kwargs = super().deconstruct()
        del kwargs["condition"]
        return path, args, kwargs
class BaseModel(SafeDeleteModel):
    """Base model for all our models.

    It is based on ``SafeDeleteModel`` to easily manage how we want the instances
    to be deleted/soft-deleted, with or without its relationships.

    The default ``safedelete`` policy is ``SOFT_DELETE_CASCADE``, ie the object to
    delete and its relations will be soft deleted: their ``deleted`` field will be
    filled with the current date-time (the opposite, ``None``, is the same as
    "not deleted")

    Also it adds some checks run with ``django check``:

    - check that all fields are correctly annotated.
    - same for fields pointing to other models: final models must have all related
      names properly annotated.
    - check that every ``ManyToManyField`` use a defined ``through`` table.
    - check that every model have a ``db_table`` defined, not prefixed with the name
      of the app or the project.

    """

    _safedelete_policy = SOFT_DELETE_CASCADE

    class Meta:
        """Options for the ``BaseModel`` model."""

        abstract = True

    @classmethod
    def _get_expected_field_annotation(cls, field: models.Field) -> Tuple[type, str]:
        """Get the expected annotation for a django model field.

        For non relational fields, the classes of the field are explored following
        its method resolution order, so the most specific class present in
        ``fields_type_mapping`` is the one that is used. This matters when a mapped
        field class inherits from another mapped one (in django, ``DateTimeField``
        is a subclass of ``DateField``): a ``DateTimeField`` is expected to be
        annotated with ``datetime``, not ``date``.

        Parameters
        ----------
        field: models.Field
            The instance of a field of the ``cls`` model for which we want to know
            the expected annotation.

        Returns
        -------
        Tuple[type, str]
            A tuple with two entries: the expected type annotation, and a string
            representation of this annotation.

        Raises
        ------
        ValueError
            If the field is not a ``ManyToManyField``, a ``ForeignKey``, a
            ``OneToOneField`` and not of a type defined in ``fields_type_mapping``.

        """
        if isinstance(field, models.ManyToManyField):
            return M2MType[field.related_model], '{}["{}"]'.format(  # type: ignore
                M2MType.__name__, field.related_model._meta.object_name
            )

        if isinstance(field, models.ForeignKey):  # covers OneToOneField too
            return field.related_model, '"{}"'.format(
                field.related_model._meta.object_name
            )

        # Closest ancestor first: do not rely on the ordering of the mapping.
        field_class: type
        for field_class in type(field).__mro__:
            if field_class in fields_type_mapping:
                field_type: Typing = fields_type_mapping[field_class]
                return field_type, field_type.__name__

        raise ValueError(
            "Field type not yet managed for '{}': {}".format(
                field.name, field.__class__.__name__
            )
        )

    @classmethod
    def _check_annotated_fields(cls) -> CheckMessages:
        """Check that all fields are correctly annotated.

        Only the concrete fields that are not auto-created are checked. A field
        without annotation is ignored if the model defining it is not in an app
        listed in ``CHECKED_APPS``.

        Returns
        -------
        List[checks.CheckMessage]
            A list of the check messages representing problems found on the model.

        """
        from marsha.core import (
            models as core_models
        )  # imported here to avoid cyclic import

        fields: List[models.Field] = [
            field
            for field in cls._meta.get_fields()
            if field.concrete and not field.auto_created
        ]
        fields_by_model: Dict[str, Type[models.Model]] = _get_fields_by_source_model(
            cls
        )
        errors: CheckMessages = []
        model_full_name: str = "{}.{}".format(
            cls._meta.app_label, cls._meta.object_name
        )

        # The resolved annotations are the same for all the fields: they are
        # computed only once, and only if at least one field is annotated.
        type_hints: Dict[str, Any] = {}

        field: models.Field
        for field in fields:
            field_name: str = field.name

            try:
                expected_type, expected_string = cls._get_expected_field_annotation(
                    field
                )
            except ValueError:
                errors.append(
                    checks.Error(
                        "The expected annotation for the field '{}' on the model '{}', a "
                        "'{}' is not known: please define it in 'fields_type_mapping'".format(
                            field_name, model_full_name, field.__class__
                        ),
                        obj=cls,
                        id="marsha.models.E001",
                    )
                )
                continue

            # check that the field is annotated
            if field_name not in cls.__annotations__:
                # ignore if defined in a model outside of the scope
                if fields_by_model[field_name]._meta.app_label not in CHECKED_APPS:
                    continue
                errors.append(
                    checks.Error(
                        "There is no typing annotation for the field '{}' "
                        "on the model '{}'".format(field_name, model_full_name),
                        hint="Add the annotation for the '{}' field on "
                        "the model '{}': ': {}'".format(
                            field_name, model_full_name, expected_string
                        ),
                        obj=cls,
                        id="marsha.models.E002",
                    )
                )
                continue

            # check that the field is correctly annotated
            if not type_hints:
                type_hints = get_type_hints(cls, core_models.__dict__)
            annotation_type = type_hints[field_name]
            if annotation_type != expected_type:
                errors.append(
                    checks.Error(
                        "The typing annotation is wrong for the field '{}' on "
                        "the model '{}': it should be '{}', not '{}'".format(
                            field_name,
                            model_full_name,
                            expected_string,
                            annotation_type,
                        ),
                        hint="Change the annotation for the '{}' field on the model "
                        "'{}': ': {}'".format(
                            field_name, model_full_name, expected_string
                        ),
                        obj=cls,
                        id="marsha.models.E003",
                    )
                )

        return errors

    @classmethod
    def _check_annotated_related_names(cls) -> CheckMessages:
        """Check that all related names are defined and annotated on final models.

        Only the reverse relations coming from a field defined on a model of an app
        listed in ``CHECKED_APPS`` are checked.

        Returns
        -------
        List[checks.CheckMessage]
            A list of the check messages representing problems found on the model.

        """
        from marsha.core import (
            models as core_models
        )  # imported here to avoid cyclic import

        # only the reverse side of a relation has both ``related_name`` and ``field``
        related_fields: List[ForeignObjectRel] = [
            field
            for field in cls._meta.get_fields()
            if field.is_relation
            and hasattr(field, "related_name")
            and field.field.model._meta.app_label in CHECKED_APPS
        ]
        errors: CheckMessages = []
        model_full_name: str = "{}.{}".format(
            cls._meta.app_label, cls._meta.object_name
        )

        # The resolved annotations are the same for all the relations: they are
        # computed only once, and only if at least one related name is annotated.
        type_hints: Dict[str, Any] = {}

        field: ForeignObjectRel
        for field in related_fields:
            field_name: str = field.field.name
            related_name: str = field.related_name
            related_model: Type[models.Model] = field.field.model
            related_model_name: str = related_model._meta.object_name
            related_model_full_name: str = "{}.{}".format(
                related_model._meta.app_label, related_model_name
            )

            # first, check that related name are defined on fk/m2m/o2o fields
            if not related_name:
                errors.append(
                    checks.Error(
                        "The field '{}' on the model '{}', pointing to the model '{}' "
                        "doesn't have the 'related_name' attribute defined.".format(
                            field_name, related_model_full_name, model_full_name
                        ),
                        hint="Set the 'related_name' argument when declaring the "
                        "'{}' field on the model '{}'".format(
                            field_name, related_model_full_name
                        ),
                        obj=cls,
                        id="marsha.models.E004",
                    )
                )
                continue

            # then check that related names are annotated on final models
            expected_annotation_type: Typing
            expected_annotation_string: str
            if field.multiple:  # reverse relation of a ForeignKey or ManyToManyField
                reverse_type = reverse_fields_type_mapping[field.field.__class__]
                expected_annotation_type = reverse_type[related_model]  # type: ignore
                expected_annotation_string = '{}["{}"]'.format(
                    reverse_type.__name__, related_model_name
                )
            else:  # reverse relation of a OneToOneField
                expected_annotation_type = related_model
                expected_annotation_string = '"{}"'.format(related_model_name)

            if related_name not in cls.__annotations__:
                errors.append(
                    checks.Error(
                        "There is no typing annotation for the related_name '{}' on the "
                        "model '{}', pointed by the field '{}' defined on the model '{}'".format(
                            related_name,
                            model_full_name,
                            field_name,
                            related_model_full_name,
                        ),
                        hint="Add '{}: {}' in the model '{}'".format(
                            related_name, expected_annotation_string, model_full_name
                        ),
                        obj=cls,
                        id="marsha.models.E005",
                    )
                )
                continue

            # and finally check that annotations of these related names are correct
            if not type_hints:
                type_hints = get_type_hints(cls, core_models.__dict__)
            annotation_type = type_hints[related_name]
            if annotation_type != expected_annotation_type:
                errors.append(
                    checks.Error(
                        "The typing annotation is wrong for the related_name '{}' on "
                        "the model '{}', pointed by the field '{}' defined on the "
                        "model '{}': it should be '{}', not '{}'".format(
                            related_name,
                            model_full_name,
                            field_name,
                            related_model_full_name,
                            expected_annotation_string,
                            annotation_type,
                        ),
                        hint="Change to '{}: {}' in the model '{}'".format(
                            related_name, expected_annotation_string, model_full_name
                        ),
                        obj=cls,
                        id="marsha.models.E006",
                    )
                )

        return errors

    @classmethod
    def _check_table_name(cls) -> CheckMessages:
        """Check that the table name is correctly defined.

        The ``db_table`` must be explicitly set on the ``Meta`` class of the model,
        and must not start with the label of the app, nor with the name of the top
        level package of the module defining the model.

        Returns
        -------
        List[checks.CheckMessage]
            A list of the check messages representing problems found on the model.

        """
        errors: CheckMessages = []
        model_full_name: str = "{}.{}".format(
            cls._meta.app_label, cls._meta.object_name
        )

        try:
            # ``original_attrs`` only holds what was explicitly set on ``Meta``
            db_table: str = cls._meta.original_attrs["db_table"]
        except KeyError:
            errors.append(
                checks.Error(
                    "The model '{}' must define the 'db_table' attribute on its "
                    "'Meta' class. It must not be prefixed with the name of the "
                    "app or the project.".format(model_full_name),
                    hint="Add 'db_table: str = \"{}\"' to the 'Meta' class of the "
                    "model '{}'".format(cls._meta.model_name, model_full_name),
                    obj=cls,
                    id="marsha.models.E007",
                )
            )
        else:
            app_prefix = cls._meta.app_label
            module_prefix = cls.__module__.split(".")[0]
            # A name starting with "<prefix>_" also starts with "<prefix>", so
            # testing the two bare prefixes is enough.
            if db_table.startswith((app_prefix, module_prefix)):
                errors.append(
                    checks.Error(
                        "The model 'db_table' attribute of the model '{}' must not "
                        "be prefixed with the name of the app ('{}') or the project "
                        "('{}').".format(model_full_name, app_prefix, module_prefix),
                        hint="Change to 'db_table: str = \"{}\"' in the 'Meta' class of the "
                        "model '{}'".format(cls._meta.model_name, model_full_name),
                        obj=cls,
                        id="marsha.models.E008",
                    )
                )

        return errors

    @classmethod
    def _check_through_models(cls) -> CheckMessages:
        """Check that all m2m fields have a defined ``through`` model.

        A field is ignored if the model defining it is not in an app listed in
        ``CHECKED_APPS``.

        Returns
        -------
        List[checks.CheckMessage]
            A list of the check messages representing problems found on the model.

        """
        fields_by_model: Dict[str, Type[models.Model]] = _get_fields_by_source_model(
            cls
        )
        errors: CheckMessages = []
        model_full_name: str = "{}.{}".format(
            cls._meta.app_label, cls._meta.object_name
        )
        m2m_fields: List[models.ManyToManyField] = [
            field
            for field in cls._meta.get_fields()
            if isinstance(field, models.ManyToManyField)
        ]

        field: models.ManyToManyField
        for field in m2m_fields:
            # ignore if defined in a model outside of the scope
            if fields_by_model[field.name]._meta.app_label not in CHECKED_APPS:
                continue

            # an auto-created ``through`` model means none was given explicitly
            if field.remote_field.through._meta.auto_created:
                errors.append(
                    checks.Error(
                        "The field '{}' of the model '{}' is a ManyToManyField but "
                        "without a 'through' model defined".format(
                            field.name, model_full_name
                        ),
                        hint="Add the attribute 'through' to the field '{}' of the model '{}' "
                        "and define the appropriate model".format(
                            field.name, model_full_name
                        ),
                        obj=cls,
                        id="marsha.models.E009",
                    )
                )

        return errors

    @classmethod
    def check(cls, **kwargs: Any) -> CheckMessages:
        """Add our own checks to the ones already run by django.

        The messages of the parent class come first, followed by the ones about the
        table name, the ``through`` models, the annotations of the fields and the
        annotations of the related names.

        Parameters
        ----------
        kwargs: Any
            Actually not used but asked by django to be present "for possible future
            usage".

        Returns
        -------
        List[checks.CheckMessage]
            A list of the check messages representing problems found on the model.

        """
        errors = super().check(**kwargs)
        errors.extend(cls._check_table_name())
        errors.extend(cls._check_through_models())
        errors.extend(cls._check_annotated_fields())
        errors.extend(cls._check_annotated_related_names())
        return errors

    def validate_unique(self, exclude: List[str] = None) -> None:
        """Add validation for our ``NonDeletedUniqueIndex`` replacing ``unique_together``.

        For the parameters, see ``django.db.models.base.Model.validate_unique``.

        Raises
        ------
        ValidationError
            If the uniqueness checks built from the ``NonDeletedUniqueIndex`` indexes
            report errors (in addition to what the parent implementation may raise).

        """
        super().validate_unique(exclude)

        # these uniqueness checks only make sense for non deleted instances
        # because it's the condition of these unique-together fields
        if self.deleted:
            return

        unique_checks = self._get_conditional_non_deleted_unique_checks(exclude)
        if not unique_checks:
            return

        model_class = self.__class__
        all_objects = model_class.all_objects  # type: ignore
        try:
            # we need to force ``_perform_unique_checks`` from ``SafeDeleteModel``
            # to use the default manager to ignore deleted instances
            # NOTE(review): the attribute is rebound on the class itself for the
            # duration of the call, so it is visible to any other user of the
            # class in the meantime.
            model_class.all_objects = model_class.objects  # type: ignore
            errors = self._perform_unique_checks(unique_checks)
        finally:
            model_class.all_objects = all_objects  # type: ignore

        if errors:
            raise ValidationError(errors)

    @classmethod
    def _get_conditional_non_deleted_indexes_fields(cls) -> List[List[str]]:
        """Get the fields of our conditional unique indexes for non deleted entries.

        Returns
        -------
        List[List[str]]
            A list with one entry for each ``NonDeletedUniqueIndex`` found in the
            indexes of the model. Each entry is the ``fields`` attribute of the
            index, ie all the fields that compose the unique index.

        """
        return [
            index.fields
            for index in cls._meta.indexes
            if isinstance(index, NonDeletedUniqueIndex)
        ]

    def _get_conditional_non_deleted_unique_checks(
        self, exclude: List[str] = None
    ) -> List[Tuple[Type["BaseModel"], TupleOfStr]]:
        """Extract "unique checks" from our conditional unique for ``_perform_unique_checks``.

        This is basically a copy of ``django.db.models.Model._get_unique_checks`` (only
        the part for ``unique_together``), but using output from
        ``_get_conditional_non_deleted_indexes_fields`` instead of
        ``_meta.unique_together``.

        For the parameters, see ``django.db.models.base.Model._get_unique_checks``.

        Returns
        -------
        List[Tuple[Type["BaseModel"], TupleOfStr]]
            ``_perform_unique_checks`` expect a list with each entry being a tuple with:
            - a model on which the uniqueness is defined
            - a tuple of fields to be unique together for this model

        """
        if exclude is None:
            exclude = []

        unique_togethers = [
            (self.__class__, self._get_conditional_non_deleted_indexes_fields())
        ]
        # the parents that are not based on ``BaseModel`` cannot provide the indexes
        parent_class: Type[models.Model]
        for parent_class in self._meta.get_parent_list():
            if issubclass(parent_class, BaseModel):
                unique_togethers.append(
                    (
                        parent_class,
                        # pylint: disable=protected-access
                        parent_class._get_conditional_non_deleted_indexes_fields(),
                    )
                )

        unique_checks: List[Tuple[Type["BaseModel"], TupleOfStr]] = []
        for model_class, unique_together in unique_togethers:
            for check in unique_together:
                # If one of the fields is excluded, don't add this check.
                if any(name in exclude for name in check):
                    continue
                unique_checks.append((model_class, tuple(check)))

        return unique_checks