diff --git a/api/serializers.py b/api/serializers.py index 38c2861..0d0e8c3 100755 --- a/api/serializers.py +++ b/api/serializers.py @@ -2,13 +2,39 @@ from django.contrib.auth.models import Group from django.contrib.auth import get_user_model from rest_framework import serializers from qrtr_account.models import Account, Bank, Institution, Transaction, Slice, Rule +from user.models import User from connection.models import Connection, ConnectionType -class UserSerializer(serializers.HyperlinkedModelSerializer): +class UserAccountSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = get_user_model() - fields = ['url', 'username', 'email', 'groups', 'owned_accounts', + fields = ['pk', 'url', 'username', 'email', 'groups'] + + +class AccountReadSerializer(serializers.HyperlinkedModelSerializer): + owner = UserAccountSerializer(read_only=True) + admin_users = UserAccountSerializer(many=True, read_only=True) + view_users = UserAccountSerializer(many=True, read_only=True) + class Meta: + model = Account + fields = ['pk', 'url', 'owner', 'name', 'admin_users', 'view_users'] + +class AccountWriteSerializer(serializers.HyperlinkedModelSerializer): + owner = serializers.SlugRelatedField(slug_field='pk', queryset=User.objects.all()) + admin_users = serializers.SlugRelatedField(slug_field='pk', queryset=User.objects.all(), many=True) + view_users = serializers.SlugRelatedField(slug_field='pk', queryset=User.objects.all(), many=True) + class Meta: + model = Account + fields = ['pk', 'url', 'owner', 'name', 'admin_users', 'view_users'] + +class UserSerializer(serializers.HyperlinkedModelSerializer): + owned_accounts = AccountReadSerializer(many=True, read_only=True) + admin_accounts = AccountReadSerializer(many=True, read_only=True) + view_accounts = AccountReadSerializer(many=True, read_only=True) + class Meta: + model = get_user_model() + fields = ['pk', 'url', 'username', 'email', 'groups', 'owned_accounts', 'admin_accounts', 'view_accounts'] @@ -18,12 +44,6 @@ class GroupSerializer(serializers.HyperlinkedModelSerializer): fields = ['url', 'name'] -class AccountSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = Account - fields = ['url', 'owner', 'name', 'admin_users', 'view_users'] - - class ConnectionTypeSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ConnectionType diff --git a/connection/connections/plaid_client.py b/connection/connections/plaid_client.py index bb3445c..b73e66c 100755 --- a/connection/connections/plaid_client.py +++ b/connection/connections/plaid_client.py @@ -44,11 +44,14 @@ class Connection(AbstractConnectionClient): public_key=self.PLAID_PUBLIC_KEY, environment=self.PLAID_ENV, api_version='2019-05-29') - print("Getting Public Key") public_key = self.credentials.get('public_token') - print("Retrieving auth token") - if not self.credentials.get('auth_token') and public_key: - self.credentials['auth_token'] = self.get_auth_token(public_key) + auth_token = self.credentials.get('auth_token') + if not auth_token and public_key: + print("Getting Auth Token From Public Key") + auth_token = self.get_auth_token(public_key) + if "error" in auth_token: + raise ValueError(f"Unable to generate Auth Token, {auth_token}") + self.credentials['auth_token'] = auth_token print("Plaid Connection successful") print(self.credentials) diff --git a/connection/views.py b/connection/views.py index 8e467f3..79fd885 100644 --- a/connection/views.py +++ b/connection/views.py @@ -60,18 +60,4 @@ class ConnectionViewSet(viewsets.ModelViewSet): plaid_client = plaid.Connection(request.data) conn.credentials = plaid_client.credentials conn.save() - return Response(plaid_client.get_accounts()) - - @action(detail=False, methods=['get'], url_path='accounts') - def get_accounts(self,request): - print("GETTING ACCOUNTS!") - print(request.user) - connections = [] - user_qrtr_accounts = request.user.owned_accounts.all() | \ - request.user.admin_accounts.all() | \ - request.user.view_accounts.all() - for qrtr_account in user_qrtr_accounts: - connections = qrtr_account.connection__set.all() - for connection in connections: - connections.append(connection.get_accounts()) - return Response(200) \ No newline at end of file + return Response(plaid_client.get_accounts()) \ No newline at end of file diff --git a/qrtr_account/views.py b/qrtr_account/views.py index 6f0eaec..4001520 100644 --- a/qrtr_account/views.py +++ b/qrtr_account/views.py @@ -2,7 +2,7 @@ from django.shortcuts import render from rest_framework import viewsets, mixins from .models import Account, Bank, Institution, Transaction, Slice, Rule from connection.models import Connection, ConnectionType -from api.serializers import (AccountSerializer, +from api.serializers import (AccountReadSerializer, AccountWriteSerializer, BankSerializer, BankSerializerPOST, InstitutionSerializer, TransactionSerializer, @@ -14,6 +14,7 @@ from allauth.socialaccount.providers.facebook.views import FacebookOAuth2Adapter from dj_rest_auth.registration.views import SocialLoginView from allauth.socialaccount.providers.twitter.views import TwitterOAuthAdapter from dj_rest_auth.social_serializers import TwitterLoginSerializer +from api.mixins import ReadWriteSerializerMixin class TwitterLogin(SocialLoginView): @@ -25,11 +26,12 @@ class FacebookLogin(SocialLoginView): adapter_class = FacebookOAuth2Adapter -class AccountViewSet(viewsets.ModelViewSet): +class AccountViewSet(ReadWriteSerializerMixin, viewsets.ModelViewSet): """API endpoint that allows accounts to be viewed or edited """ queryset = Account.objects.all() - serializer_class = AccountSerializer + read_serializer_class = AccountReadSerializer + write_serializer_class = AccountWriteSerializer class BankViewSet(viewsets.ModelViewSet): diff --git a/user/views.py b/user/views.py index ac0448d..2a89eea 100644 --- a/user/views.py +++ b/user/views.py @@ -7,8 +7,11 @@ from allauth.account.views import ConfirmEmailView from django.shortcuts import redirect, render from django.http import Http404 from django.views.generic.base import TemplateView +from rest_framework.response import Response +from rest_framework.decorators import action - +import importlib +from collections import defaultdict class UserViewSet(viewsets.ModelViewSet): """ @@ -17,6 +20,32 @@ class UserViewSet(viewsets.ModelViewSet): queryset = get_user_model().objects.all().order_by('-date_joined') serializer_class = UserSerializer + @action(detail=False, methods=['get'], url_path="me") + def me(self, request, pk=None): + user = UserSerializer(request.user, context={'request': request}) + return Response(user.data) + + @action(detail=False, methods=['get'], url_path='list-connections') + def get_accounts(self,request): + print("GETTING ACCOUNTS!") + print(request.user) + avail_conns = defaultdict(list) + user_qrtr_accounts = request.user.owned_accounts.all() | \ + request.user.admin_accounts.all() | \ + request.user.view_accounts.all() + for qrtr_account in user_qrtr_accounts.distinct(): + connections = qrtr_account.connection_set.all() + for connection in connections: + conn_name = connection.type.name + conn_accs = [] + client_lib = importlib.import_module(f"connection.connections.{connection.type.filename}") + client = client_lib.Connection(connection.credentials) + connection.credentials = client.credentials + connection.save() + conn_accs.append(client.get_accounts()) + avail_conns[conn_name].extend(conn_accs) + return Response(avail_conns) + class GroupViewSet(viewsets.ReadOnlyModelViewSet): """