跳转至

drf-writable-nested嵌套新增更新

安装

本文requirement环境如下:

Django==2.0.4
djangorestframework==3.8.2
drf-writable-nested==0.4.2

外键或O2O直接嵌套新增

模型

# User account model.
class User(AbstractBaseUser, PermissionsMixin):
    # Mobile phone number: nullable, but unique whenever a value is present.
    tel = CharNullField(
        max_length=20,
        unique=True,
        null=True,
        error_messages={'unique': "具有该手机号码的用户已存在"},
        verbose_name='手机号码',
    )
    # Person's name; may be left blank.
    name = models.CharField(max_length=30, blank=True, verbose_name='姓名')


# Staff (employee) record attached to a user account.
class Staff(models.Model):
    # Related user (one-to-one); deleting the user cascades to this row.
    user = models.OneToOneField(
        'user.User',
        null=True,
        on_delete=models.CASCADE,
        verbose_name='对应用户',
    )
    # Staff number, unique across all staff rows.
    no = models.CharField(max_length=255, unique=True, verbose_name='工号')

视图

views.py

# Staff CRUD endpoints backed by the nested writable serializer.
# NOTE(review): the class docstring is deliberately left untouched -- its
# `create:` / `partial_update:` sections look like per-action descriptions that
# DRF renders in the generated API documentation, i.e. user-facing text.
class StaffViewSet(ModelViewSet):
    """
    create:
    新增数据格式:
    <p><code>{
            'user': {'name': 'YKH', 'tel': '17749503263', 'email': 'ykh@dreamgo.tech', 'password': 'YKH123456',
                     'is_active': True, 'is_staff': True, 'groups': [self.group.id]},
            'job_state': 1, 'bank_no': '614457662', 'social_security_no': '暂无', 'on_job': True
        }</code>
    partial_update:
    更新数据格式:
        <p><code>{
            'user': {'id': user.id, 'name': 'YKH', 'tel': '17749503263', 'email': 'ykh@dreamgo.tech',
             'password': 'YKH123456', 'is_active': True, 'is_staff': True, 'groups': [self.group.id]},
            'job_state': 1, 'bank_no': '614457662', 'social_security_no': '暂无', 'on_job': True
        }</code>
    """
    # Every Staff row is reachable through this viewset.
    queryset = Staff.objects.all()
    # One serializer is used for all actions; it accepts the nested `user` payload.
    serializer_class = StaffCreateSerializer

serializers.py

# Serializer for the User that is nested inside a Staff payload.
class StaffUserCreateSerializer(ModelSerializer):
    """Create or update the ``User`` nested in a staff request.

    When the incoming payload carries an ``id``, the uniqueness validators of
    this serializer are narrowed so that the row being updated does not
    conflict with itself.
    """

    def to_internal_value(self, data):
        """Relax unique validators for the addressed row, then validate ``data``.

        :param data: raw nested payload; only mappings are inspected for ``id``.
        :return: the validated data produced by the parent implementation.
        :raises serializers.ValidationError: if ``id`` is malformed or the
            parent validation fails.
        """
        # Imported once per call instead of once per validator in the loop.
        from collections.abc import Mapping
        from rest_framework.validators import UniqueValidator

        # Non-mapping input (None, str, ...) is left to the parent class, which
        # reports it as a validation error instead of crashing on `'id' in data`.
        if isinstance(data, Mapping) and 'id' in data and 'id' in self.fields:
            try:
                obj_id = self.fields['id'].to_internal_value(data['id'])
            except serializers.ValidationError as exc:
                raise serializers.ValidationError({'id': exc.detail})
            for field in self.fields.values():
                for validator in field.validators:
                    # isinstance() also covers subclasses of UniqueValidator.
                    if isinstance(validator, UniqueValidator):
                        # Exclude id from queryset for checking uniqueness.
                        # NOTE(review): this mutates the validator in place, so
                        # exclusions accumulate if the same serializer instance
                        # validates several payloads -- confirm that is acceptable.
                        validator.queryset = validator.queryset.exclude(id=obj_id)
        return super(StaffUserCreateSerializer, self).to_internal_value(data)

    def create(self, validated_data):
        """Create the user through ``create_user`` and attach its groups.

        :param validated_data: validated user attributes; ``groups`` is optional.
        :return: the newly created user.
        """
        groups = validated_data.pop('groups', None)
        # Use create_user rather than create (the manager method, not a plain insert).
        instance = User.objects.create_user(**validated_data)
        if groups is not None:
            instance.groups.add(*groups)
        return instance

    def update(self, instance, validated_data):
        """Update the user; a supplied password goes through ``set_password``.

        :param instance: the user being updated.
        :param validated_data: validated attributes; ``password`` is optional.
        :return: the updated user.
        """
        # Remove the raw password so the parent update never assigns it as-is.
        password = validated_data.pop('password', None)
        super(StaffUserCreateSerializer, self).update(instance, validated_data)
        if password is not None:
            instance.set_password(password)
            instance.save()
        return instance

    def validate_password(self, value):
        """Run the configured password validators and return ``value`` unchanged."""
        password_validation.validate_password(value)
        return value

    class Meta:
        model = User
        # NOTE(review): `password` is not declared write-only here, so it is
        # also part of the serialized output -- consider
        # extra_kwargs = {'password': {'write_only': True}}.
        fields = ('id', 'name', 'tel', 'email', 'password', 'is_active', 'is_staff', 'groups')

class StaffCreateSerializer(WritableNestedModelSerializer):
    """Writable ``Staff`` serializer with the related ``user`` nested inline."""

    # Nested user payload, delegated to StaffUserCreateSerializer.
    user = StaffUserCreateSerializer(label='用户')

    class Meta:
        model = Staff
        fields = (
            'user',
            'no',
            'job_state',
            'bank_no',
            'social_security_no',
            'on_job',
        )

tests.py:

class StaffTests(APITestCase):
    """API tests for nested create/update of ``Staff`` together with its ``User``."""

    def setUp(self):
        """Create an authorised caller, a group and the payloads used by the tests."""
        self.user = User.objects.create_user(tel='18094213193', password='123456')
        self.user.user_permissions.add(*get_model_permission(Staff))
        token, _ = Token.objects.get_or_create(user=self.user)
        self.access_token = token.access_token
        # Fixture data
        self.group = Group.objects.create(name='测试')
        self.data = {
            'user': {'name': 'YKH', 'tel': '17749503263', 'email': 'ykh@dreamgo.tech', 'password': 'YKH123456',
                     'is_active': True, 'is_staff': True, 'groups': [self.group.id]},
            'job_state': 1, 'bank_no': '614457662', 'social_security_no': '暂无', 'on_job': True
        }
        self.data1 = {
            'user': {'name': 'YKH1', 'tel': '177495023263', 'email': 'y2kh@dreamgo.tech', 'password': 'YKH123456',
                     'is_active': True, 'is_staff': True, 'groups': [self.group.id]},
            'job_state': 1, 'bank_no': '614457662', 'social_security_no': '暂无', 'on_job': True
        }

    def test_staff_create(self):
        """Create two staff records, each with a nested user."""
        url = reverse('staff-list')
        self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.access_token)
        response = self.client.post(url, self.data, format='json')
        print('创建', json.dumps(response.data, ensure_ascii=False, indent=2))
        # NOTE(review): a stock ModelViewSet answers a create with 201; this
        # presumably relies on a project-specific create() -- confirm.
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Rename the first record's staff number before creating the second one.
        Staff.objects.all().update(no='Staff002')
        response = self.client.post(url, self.data1, format='json')
        print('创建', json.dumps(response.data, ensure_ascii=False, indent=2))
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_staff_update(self):
        """Partially update a staff record and its nested user, including the password."""
        user = User.objects.create_user(tel='17749503263', password='YKH123456')
        instance = Staff.objects.create(user=user, no='17749503263')
        print('原始:', instance.user.id)
        print('旧密码登陆结果:', authenticate(tel='17749503263', password='YKH123456'))
        url = reverse('staff-detail', kwargs={'pk': instance.id})
        self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.access_token)
        data = {
            # Address the nested user by its real primary key; a hard-coded id
            # only matches when the database happens to hand out that value.
            'user': {'id': user.id, 'name': 'YKH1', 'tel': '17749503263', 'password': 'YKH654321'},
            'job_state': 0, 'bank_no': '614457662'
        }
        response = self.client.patch(url, data, format='json')
        print('更新', json.dumps(response.data, ensure_ascii=False, indent=2))
        print('新密码登陆结果:', authenticate(tel='17749503263', password='YKH654321'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)

测试输出:

创建 {
  "id": 1,
  "user": {
    "id": 2,
    "tel": "17749503263",
    "portrait": "http://testserver/media/default/user/default.png",
    "groups": [
      {
        "id": 1,
        "name": "只读"
      },
      {
        "id": 2,
        "name": "测试"
      }
    ],
    "last_login": "",
    "gender": 2,
    "name": "YKH",
    "birth_day": "",
    "email": "ykh@dreamgo.tech",
    "fixed_tel": "",
    "id_no": "",
    "qq": "",
    "contact_address": "",
    "brief_code": "",
    "get_gender_display": "保密",
    "get_full_name": "YKH",
    "is_staff": true,
    "is_superuser": false,
    "is_active": true
  },
  "no": "staff001",
  "job_state": 1,
  "bank_no": "614457662",
  "social_security_no": "暂无",
  "on_job": true,
  "create_time": "2018-05-08 16:20:14",
  "update_time": "2018-05-08 16:20:14"
}
----------------------------------------------------
原始: 2
旧密码登陆结果: 2 17749503263
更新 {
  "user": {
    "id": 2,
    "name": "YKH1",
    "tel": "17749503263",
    "email": "",
    "password": "pbkdf2_sha256$100000$OCufCY1edUNM$nxQpQCU77q+HXh6ADuiWqcV9azR1uLThDsuO+ZuWEOQ=",
    "is_active": true,
    "is_staff": false,
    "groups": [
      1
    ]
  },
  "no": "17749503263",
  "job_state": 0,
  "bank_no": "614457662",
  "social_security_no": "",
  "on_job": true
}
新密码登陆结果: 2 YKH1

M2M直接嵌套新增

模型

# Outbound (stock-out) record.
class StorageOut(models.Model):
    # Business document number, unique per record.
    no = models.CharField(max_length=255, unique=True, verbose_name='业务单号')
    # Line items belonging to this outbound record (many-to-many).
    items = models.ManyToManyField(
        'storage.StorageOutItem',
        verbose_name='出库子记录',
    )

# Single line item of an outbound record.
class StorageOutItem(models.Model):
    # Stock entry the goods are taken from; set to NULL if that stock row is deleted.
    stock = models.ForeignKey(
        'storage.Stock',
        related_name='storage_out_item_stock',
        null=True,
        on_delete=models.SET_NULL,
        blank=True,
        verbose_name='库存',
    )
    # Quantity taken out; defaults to 0.
    out_num = models.PositiveIntegerField(default=0, verbose_name='出库数量')

视图

views.py

# Outbound-record endpoints.
# NOTE(review): the class docstring is left untouched because its `create:`
# section looks like a per-action description rendered in the API docs.
# NOTE(review): BulkUpdateModelMixin is listed after ModelViewSet, so any method
# both define resolves to ModelViewSet's version -- confirm this order is intended.
class StorageOutViewSet(ModelViewSet, BulkUpdateModelMixin):
    """
    create:<code>{'storage': 仓库ID, 'category': 0, 'remarks': '出入备注',
        'items': [{'stock': 库存ID, 'out_num': 2}]}</code>
    """
    # All outbound records are exposed.
    queryset = StorageOut.objects.all()
    # Nested writable serializer that accepts the `items` list inline.
    serializer_class = StorageOutCreateSerializer

serializers.py

# Serializer for one outbound line item.
class StorageOutItemModifySerializer(ModelSerializer):
    """Serialize a ``StorageOutItem``; ``stock`` must always be supplied."""

    class Meta:
        model = StorageOutItem
        fields = ('stock', 'out_num')
        # The model allows an empty stock, the API does not.
        extra_kwargs = {
            'stock': {'required': True},
        }

# Serializer that creates an outbound record together with its items.
class StorageOutCreateSerializer(WritableNestedModelSerializer):
    """Writable ``StorageOut`` serializer with a required nested ``items`` list."""

    items = StorageOutItemModifySerializer(
        required=True,
        many=True,
        label='出库子记录',
    )

    class Meta:
        model = StorageOut
        fields = ('no', 'items')

tests.py

class StorageOutTests(APITestCase):
    """API test for creating an outbound record with nested items."""

    def setUp(self):
        # Fixture creation is elided in this example.
        ...

    def test_out(self):
        """Post an outbound record with one nested item and print the response."""
        # NOTE(review): `url_out` is not defined in this snippet -- presumably it
        # comes from the elided fixture code; confirm before running.
        # The 'no001' literal was missing its closing quote, which made the
        # whole module fail to parse.
        payload = {'no': 'no001', 'items': [{'stock': 1, 'out_num': 2}]}
        response_out = self.client.post(url_out, payload, format='json')
        print('出库新增', json.dumps(response_out.data, ensure_ascii=False, indent=2))

测试输出:

出库新增 {
  "id": 1,
  "no": "NO001",
  "items": [
    {
      "stock": 1,
      "medicine_batch": {
        "id": 1,
        "medicine": {
          "id": 1,
          "name": "999感冒灵",
          "no": "NO999"
        }
      },
      "stock_num": 3,
      "out_num": 2,
      "unit_price": "4.0000",
      "total_price": "8.0000",
      "create_time": "2018-05-08 16:36:58",
      "update_time": "2018-05-08 16:36:58"
    }
  ],
  "create_time": "2018-05-08 16:36:58",
  "update_time": "2018-05-08 16:36:58"
}

Note:更新M2M字段时,如果其中一部分记录既不想修改、也不想上传对应的id,则需要将serializer里的update方法重写为:

def update(self, instance, validated_data):
    """Nested update that keeps related rows which are absent from the payload.

    Same steps as a regular nested update in this snippet, except that the
    final clean-up call is commented out, so existing reverse/M2M rows that
    were not sent by the client are left in place.

    :param instance: the object being updated.
    :param validated_data: validated payload; nested relations are split off first.
    :return: the updated instance.
    """
    # Split nested relation data from the plain field values.
    relations, reverse_relations = self._extract_relations(validated_data)
    # Create or update direct relations (foreign key, one-to-one)
    self.update_or_create_direct_relations(validated_data, relations, )
    # Update instance
    # NOTE(review): PrescriptionCheckEscapeSerializer is presumably the serializer
    # this method is pasted into. If the next `update` in its MRO is the nested
    # mixin's own implementation, direct relations may be processed a second
    # time there -- confirm against the drf-writable-nested source.
    instance = super(PrescriptionCheckEscapeSerializer, self).update(instance, validated_data)
    self.update_or_create_reverse_relations(instance, reverse_relations)
    # Do not delete the old M2M data: the clean-up step is intentionally disabled.
    # self.delete_reverse_relations_if_need(instance, reverse_relations)
    return instance

外键反向级联新增

模型

# Measurement standard (a group of units).
class UnitGroup(models.Model):
    # Name of the standard.
    name = models.CharField(max_length=255, verbose_name='名称')

# Measurement unit belonging to a measurement standard.
class Unit(models.Model):
    # Owning standard; set to NULL if the standard is deleted. The reverse
    # accessor on UnitGroup is `unit_unit_group`.
    unit_group = models.ForeignKey(
        'medicine.UnitGroup',
        related_name='unit_unit_group',
        null=True,
        on_delete=models.SET_NULL,
        verbose_name='计量标准',
    )
    # Name of the unit.
    name = models.CharField(max_length=255, verbose_name='名称')
    # Name shown to users.
    display_name = models.CharField(max_length=255, verbose_name='显示名称')

视图

views.py 同上

serializers.py:

# Serializer used to create a measurement unit.
class UnitCreateSerializer(ModelSerializer):
    """Serialize the writable fields of a ``Unit``."""

    class Meta:
        model = Unit
        fields = (
            'name',
            'display_name',
            'ratio',
            'is_active',
        )

# Serializer that creates a measurement standard together with its units.
class UnitGroupCreateSerializer(WritableNestedModelSerializer):
    """Writable ``UnitGroup`` serializer with its units nested as a list."""

    # The field name is the reverse accessor declared on Unit.unit_group.
    unit_unit_group = UnitCreateSerializer(many=True)

    class Meta:
        model = UnitGroup
        fields = (
            'unit_unit_group',
            'name',
        )

tests.py:

class UnitGroupTests(APITestCase):
    """API test for creating a ``UnitGroup`` with nested units."""

    def setUp(self):
        """Fetch or create the calling user, grant permissions and get a token."""
        tel = '18094213198'
        try:
            account = User.objects.get(tel=tel)
        except User.DoesNotExist:
            account = User.objects.create_user(tel=tel, password='123456')
        self.user = account
        self.user.user_permissions.add(*get_model_permission(UnitGroup))
        token, _ = Token.objects.get_or_create(user=self.user)
        self.access_token = token.access_token

    def test_create(self):
        """Create a standard with one nested unit and check the reverse link."""
        endpoint = reverse('unitgroup-list')
        self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.access_token)
        payload = {
            'name': '重量',
            'unit_unit_group': [{'name': '1', 'display_name': 'yi'}],
        }
        response = self.client.post(endpoint, payload, format='json')
        # Show every unit next to the standard it ended up attached to.
        print([(unit, unit.unit_group) for unit in Unit.objects.all()])
        print('创建', json.dumps(response.data, ensure_ascii=False, indent=2))
        self.assertEqual(response.status_code, status.HTTP_200_OK)

测试输出:

[(<Unit: 1>, <UnitGroup: 重量>)]
创建 {
  "id": 1,
  "name": "重量",
  "create_time": "2018-05-08 16:59:51",
  "update_time": "2018-05-08 16:59:51"
}

unique=True

当模型中有unique=True的字段时,如果在嵌套更新时传入了与更新前相同的值,依旧会提示该字段已存在。下面以“外键或O2O直接嵌套新增”一节中的模型为例进行讲解,其更新时使用的serializer为:

# Serializer used when updating a staff record.
class StaffUpdateSerializer(WritableNestedModelSerializer):
    """Writable ``Staff`` serializer whose nested ``user`` accepts partial data."""

    user = StaffUserCreateSerializer(
        label='用户',
        partial=True,
    )

    class Meta:
        model = Staff
        fields = (
            'user',
            'no',
            'job_state',
            'bank_no',
            'social_security_no',
            'on_job',
        )

在python shell输入:

# Print the fields and validators generated for the update serializer.
from user.models import User, Staff
from user.serializers import StaffUpdateSerializer

staff = Staff.objects.first()
serializer = StaffUpdateSerializer(staff, partial=True)
print(repr(serializer))
得到
StaffUpdateSerializer(<Staff: s>, partial=True):
user = StaffUserCreateSerializer(label='用户', partial=True):
    id = IntegerField(label='ID', read_only=True, required=False)
    name = CharField(allow_blank=True, label='姓名', max_length=30, required=False)
    tel = CharField(allow_null=True, label='手机号码', max_length=20, required=True, validators=[<UniqueValidator(queryset=User.objects.all())>])
    email = CharField(allow_blank=True, allow_null=True, label='电子邮件', max_length=255, required=False, validators=[<UniqueValidator(queryset=User.objects.all())>])
    password = CharField(label='密码', max_length=128, required=True)
    is_active = BooleanField(label='有效', required=False)
    is_staff = BooleanField(label='职员状态', required=False)
    groups = PrimaryKeyRelatedField(help_text='该用户归属的组。一个用户将得到其归属的组的所有权限。', label='组', many=True, queryset=Group.objects.all(), required=False)
no = CharField(label='工号', max_length=255, required=True, validators=[<UniqueValidator(queryset=Staff.objects.all())>])
job_state = ChoiceField(choices=dict_items([(0, '全职'), (1, '兼职')]), label='入职类型', required=False, validators=[<django.core.validators.MinValueValidator object>, <django.core.validators.MaxValueValidator object>])
bank_no = CharField(allow_blank=True, label='银行卡号码', max_length=30, required=False)
social_security_no = CharField(allow_blank=True, label='社保号', max_length=30, required=False)
on_job = BooleanField(label='是否在职', required=False)

可以看到StaffUserCreateSerializer中没有instance对象,所以tel字段的UniqueValidator范围为queryset=User.objects.all()。由此可见导致上述问题的主要原因是serializer在实例化的时候并没有把staff对应的user作为instance传递给StaffUserCreateSerializer

下面给出两种解决方案,有兴趣的可以查看:

方法1:在 rest_framework/validators.py 的 BaseUniqueForValidator 中有如下方法:当serializer带有instance时,唯一性验证的范围会排除当前instance,这就是部分更新instance时上传其原先的唯一值不会报错的原因。

# Quoted as-is for illustration; the code itself is intentionally unchanged.
def exclude_current_instance(self, attrs, queryset):
    """
    If an instance is being updated, then do not include
    that instance itself as a uniqueness conflict.
    """
    # `attrs` is not used in this body; only `self.instance` decides the result.
    if self.instance is not None:
        # Updating: drop the row being updated from the uniqueness check.
        return queryset.exclude(pk=self.instance.pk)
    # No instance set: the queryset is returned unchanged.
    return queryset

基于上面的思路,我们可以为UniqueValidator重新筛选范围:

class UserSerializer(ModelSerializer):
    """User serializer that lets a nested update keep its own unique values.

    When the payload carries a ``pk``, every ``UniqueValidator`` on this
    serializer is narrowed to ignore the row with that primary key.
    """

    # Reverse OneToOne relation
    profile = ProfileSerializer(required=False, allow_null=True)

    def to_internal_value(self, data):
        """Relax unique validators for the addressed row, then validate ``data``.

        :param data: raw payload; only mappings are inspected for ``pk``.
        :return: the validated data produced by the parent implementation.
        :raises serializers.ValidationError: if ``pk`` is malformed or the
            parent validation fails.
        """
        # Imported once per call instead of once per validator in the loop.
        from collections.abc import Mapping
        from rest_framework.validators import UniqueValidator

        # Non-mapping input is left to the parent class, which reports it as a
        # validation error instead of crashing on `'pk' in data`.
        # The former debug print of the raw payload was removed: it would write
        # request data (potentially credentials) to stdout.
        if isinstance(data, Mapping) and 'pk' in data and 'pk' in self.fields:
            try:
                obj_id = self.fields['pk'].to_internal_value(data['pk'])
            except serializers.ValidationError as exc:
                raise serializers.ValidationError({'pk': exc.detail})
            for field in self.fields.values():
                for validator in field.validators:
                    # isinstance() also covers subclasses of UniqueValidator.
                    if isinstance(validator, UniqueValidator):
                        # Exclude the row from the uniqueness check. `pk=` works
                        # whatever the primary-key column is called.
                        validator.queryset = validator.queryset.exclude(pk=obj_id)
        return super(UserSerializer, self).to_internal_value(data)

方法2:将对应的id/pk转化为instance:

class StaffUserCreateSerializer(ModelSerializer):
    """Nested user serializer that binds itself to the row named by ``pk``.

    Setting ``self.instance`` before validation lets the stock uniqueness
    validators exclude that row, so re-sending an unchanged unique value is
    accepted.
    """

    def to_internal_value(self, data):
        """Resolve ``pk`` to an instance, then validate ``data``.

        :param data: raw nested payload; only mappings are inspected for ``pk``.
        :return: the validated data produced by the parent implementation.
        :raises serializers.ValidationError: if ``pk`` is malformed, does not
            match an existing row, or the parent validation fails.
        """
        from collections.abc import Mapping

        # Non-mapping input is left to the parent class, which reports it as a
        # validation error instead of crashing on `'pk' in data`.
        # The former debug print of the raw payload was removed: it would write
        # request data to stdout.
        if isinstance(data, Mapping) and 'pk' in data and 'pk' in self.fields:
            try:
                obj_id = self.fields['pk'].to_internal_value(data['pk'])
            except serializers.ValidationError as exc:
                raise serializers.ValidationError({'pk': exc.detail})
            model = self.Meta.model
            try:
                self.instance = model.objects.get(pk=obj_id)
            except model.DoesNotExist:
                # An unknown pk is a client error, not a server error.
                raise serializers.ValidationError(
                    {'pk': ['Invalid pk "{}" - object does not exist.'.format(obj_id)]}
                )
        return super(StaffUserCreateSerializer, self).to_internal_value(data)

    class Meta:
        model = User
        fields = ('pk', 'name', 'tel',)

class StaffModifySerializer(WritableNestedModelSerializer):
    """Writable ``Staff`` serializer whose nested ``user`` accepts partial data."""

    user = StaffUserCreateSerializer(
        label='User',
        partial=True,
    )

    class Meta:
        model = Staff
        fields = (
            'user',
            'job_state',
            'bank_no',
            'social_security_no',
            'on_job',
        )

相关Issue:

DRF issue 2403、DRF issue 2996、drf-writable-nested issue 34