Merge branches 'arm/exynos', 'arm/renesas', 'arm/rockchip', 'arm/omap', 'arm/mediatek...
[sfrench/cifs-2.6.git] / drivers / iommu / arm-smmu.c
index 2d80fa8a0634aba34b366609d8bcc50f432bb31c..3bdb799d3b4b1f5ee1de1e2505d3d0024273c658 100644 (file)
 #include <linux/amba/bus.h>
 
 #include "io-pgtable.h"
+#include "arm-smmu-regs.h"
+
+#define ARM_MMU500_ACTLR_CPRE          (1 << 1)
+
+#define ARM_MMU500_ACR_CACHE_LOCK      (1 << 26)
+#define ARM_MMU500_ACR_SMTNMB_TLBEN    (1 << 8)
+
+#define TLB_LOOP_TIMEOUT               1000000 /* 1s! */
+#define TLB_SPIN_COUNT                 10
 
 /* Maximum number of context banks per SMMU */
 #define ARM_SMMU_MAX_CBS               128
 #define smmu_write_atomic_lq           writel_relaxed
 #endif
 
-/* Configuration registers */
-#define ARM_SMMU_GR0_sCR0              0x0
-#define sCR0_CLIENTPD                  (1 << 0)
-#define sCR0_GFRE                      (1 << 1)
-#define sCR0_GFIE                      (1 << 2)
-#define sCR0_EXIDENABLE                        (1 << 3)
-#define sCR0_GCFGFRE                   (1 << 4)
-#define sCR0_GCFGFIE                   (1 << 5)
-#define sCR0_USFCFG                    (1 << 10)
-#define sCR0_VMIDPNE                   (1 << 11)
-#define sCR0_PTM                       (1 << 12)
-#define sCR0_FB                                (1 << 13)
-#define sCR0_VMID16EN                  (1 << 31)
-#define sCR0_BSU_SHIFT                 14
-#define sCR0_BSU_MASK                  0x3
-
-/* Auxiliary Configuration register */
-#define ARM_SMMU_GR0_sACR              0x10
-
-/* Identification registers */
-#define ARM_SMMU_GR0_ID0               0x20
-#define ARM_SMMU_GR0_ID1               0x24
-#define ARM_SMMU_GR0_ID2               0x28
-#define ARM_SMMU_GR0_ID3               0x2c
-#define ARM_SMMU_GR0_ID4               0x30
-#define ARM_SMMU_GR0_ID5               0x34
-#define ARM_SMMU_GR0_ID6               0x38
-#define ARM_SMMU_GR0_ID7               0x3c
-#define ARM_SMMU_GR0_sGFSR             0x48
-#define ARM_SMMU_GR0_sGFSYNR0          0x50
-#define ARM_SMMU_GR0_sGFSYNR1          0x54
-#define ARM_SMMU_GR0_sGFSYNR2          0x58
-
-#define ID0_S1TS                       (1 << 30)
-#define ID0_S2TS                       (1 << 29)
-#define ID0_NTS                                (1 << 28)
-#define ID0_SMS                                (1 << 27)
-#define ID0_ATOSNS                     (1 << 26)
-#define ID0_PTFS_NO_AARCH32            (1 << 25)
-#define ID0_PTFS_NO_AARCH32S           (1 << 24)
-#define ID0_CTTW                       (1 << 14)
-#define ID0_NUMIRPT_SHIFT              16
-#define ID0_NUMIRPT_MASK               0xff
-#define ID0_NUMSIDB_SHIFT              9
-#define ID0_NUMSIDB_MASK               0xf
-#define ID0_EXIDS                      (1 << 8)
-#define ID0_NUMSMRG_SHIFT              0
-#define ID0_NUMSMRG_MASK               0xff
-
-#define ID1_PAGESIZE                   (1 << 31)
-#define ID1_NUMPAGENDXB_SHIFT          28
-#define ID1_NUMPAGENDXB_MASK           7
-#define ID1_NUMS2CB_SHIFT              16
-#define ID1_NUMS2CB_MASK               0xff
-#define ID1_NUMCB_SHIFT                        0
-#define ID1_NUMCB_MASK                 0xff
-
-#define ID2_OAS_SHIFT                  4
-#define ID2_OAS_MASK                   0xf
-#define ID2_IAS_SHIFT                  0
-#define ID2_IAS_MASK                   0xf
-#define ID2_UBS_SHIFT                  8
-#define ID2_UBS_MASK                   0xf
-#define ID2_PTFS_4K                    (1 << 12)
-#define ID2_PTFS_16K                   (1 << 13)
-#define ID2_PTFS_64K                   (1 << 14)
-#define ID2_VMID16                     (1 << 15)
-
-#define ID7_MAJOR_SHIFT                        4
-#define ID7_MAJOR_MASK                 0xf
-
-/* Global TLB invalidation */
-#define ARM_SMMU_GR0_TLBIVMID          0x64
-#define ARM_SMMU_GR0_TLBIALLNSNH       0x68
-#define ARM_SMMU_GR0_TLBIALLH          0x6c
-#define ARM_SMMU_GR0_sTLBGSYNC         0x70
-#define ARM_SMMU_GR0_sTLBGSTATUS       0x74
-#define sTLBGSTATUS_GSACTIVE           (1 << 0)
-#define TLB_LOOP_TIMEOUT               1000000 /* 1s! */
-#define TLB_SPIN_COUNT                 10
-
-/* Stream mapping registers */
-#define ARM_SMMU_GR0_SMR(n)            (0x800 + ((n) << 2))
-#define SMR_VALID                      (1 << 31)
-#define SMR_MASK_SHIFT                 16
-#define SMR_ID_SHIFT                   0
-
-#define ARM_SMMU_GR0_S2CR(n)           (0xc00 + ((n) << 2))
-#define S2CR_CBNDX_SHIFT               0
-#define S2CR_CBNDX_MASK                        0xff
-#define S2CR_EXIDVALID                 (1 << 10)
-#define S2CR_TYPE_SHIFT                        16
-#define S2CR_TYPE_MASK                 0x3
-enum arm_smmu_s2cr_type {
-       S2CR_TYPE_TRANS,
-       S2CR_TYPE_BYPASS,
-       S2CR_TYPE_FAULT,
-};
-
-#define S2CR_PRIVCFG_SHIFT             24
-#define S2CR_PRIVCFG_MASK              0x3
-enum arm_smmu_s2cr_privcfg {
-       S2CR_PRIVCFG_DEFAULT,
-       S2CR_PRIVCFG_DIPAN,
-       S2CR_PRIVCFG_UNPRIV,
-       S2CR_PRIVCFG_PRIV,
-};
-
-/* Context bank attribute registers */
-#define ARM_SMMU_GR1_CBAR(n)           (0x0 + ((n) << 2))
-#define CBAR_VMID_SHIFT                        0
-#define CBAR_VMID_MASK                 0xff
-#define CBAR_S1_BPSHCFG_SHIFT          8
-#define CBAR_S1_BPSHCFG_MASK           3
-#define CBAR_S1_BPSHCFG_NSH            3
-#define CBAR_S1_MEMATTR_SHIFT          12
-#define CBAR_S1_MEMATTR_MASK           0xf
-#define CBAR_S1_MEMATTR_WB             0xf
-#define CBAR_TYPE_SHIFT                        16
-#define CBAR_TYPE_MASK                 0x3
-#define CBAR_TYPE_S2_TRANS             (0 << CBAR_TYPE_SHIFT)
-#define CBAR_TYPE_S1_TRANS_S2_BYPASS   (1 << CBAR_TYPE_SHIFT)
-#define CBAR_TYPE_S1_TRANS_S2_FAULT    (2 << CBAR_TYPE_SHIFT)
-#define CBAR_TYPE_S1_TRANS_S2_TRANS    (3 << CBAR_TYPE_SHIFT)
-#define CBAR_IRPTNDX_SHIFT             24
-#define CBAR_IRPTNDX_MASK              0xff
-
-#define ARM_SMMU_GR1_CBA2R(n)          (0x800 + ((n) << 2))
-#define CBA2R_RW64_32BIT               (0 << 0)
-#define CBA2R_RW64_64BIT               (1 << 0)
-#define CBA2R_VMID_SHIFT               16
-#define CBA2R_VMID_MASK                        0xffff
-
 /* Translation context bank */
 #define ARM_SMMU_CB(smmu, n)   ((smmu)->cb_base + ((n) << (smmu)->pgshift))
 
-#define ARM_SMMU_CB_SCTLR              0x0
-#define ARM_SMMU_CB_ACTLR              0x4
-#define ARM_SMMU_CB_RESUME             0x8
-#define ARM_SMMU_CB_TTBCR2             0x10
-#define ARM_SMMU_CB_TTBR0              0x20
-#define ARM_SMMU_CB_TTBR1              0x28
-#define ARM_SMMU_CB_TTBCR              0x30
-#define ARM_SMMU_CB_CONTEXTIDR         0x34
-#define ARM_SMMU_CB_S1_MAIR0           0x38
-#define ARM_SMMU_CB_S1_MAIR1           0x3c
-#define ARM_SMMU_CB_PAR                        0x50
-#define ARM_SMMU_CB_FSR                        0x58
-#define ARM_SMMU_CB_FAR                        0x60
-#define ARM_SMMU_CB_FSYNR0             0x68
-#define ARM_SMMU_CB_S1_TLBIVA          0x600
-#define ARM_SMMU_CB_S1_TLBIASID                0x610
-#define ARM_SMMU_CB_S1_TLBIVAL         0x620
-#define ARM_SMMU_CB_S2_TLBIIPAS2       0x630
-#define ARM_SMMU_CB_S2_TLBIIPAS2L      0x638
-#define ARM_SMMU_CB_TLBSYNC            0x7f0
-#define ARM_SMMU_CB_TLBSTATUS          0x7f4
-#define ARM_SMMU_CB_ATS1PR             0x800
-#define ARM_SMMU_CB_ATSR               0x8f0
-
-#define SCTLR_S1_ASIDPNE               (1 << 12)
-#define SCTLR_CFCFG                    (1 << 7)
-#define SCTLR_CFIE                     (1 << 6)
-#define SCTLR_CFRE                     (1 << 5)
-#define SCTLR_E                                (1 << 4)
-#define SCTLR_AFE                      (1 << 2)
-#define SCTLR_TRE                      (1 << 1)
-#define SCTLR_M                                (1 << 0)
-
-#define ARM_MMU500_ACTLR_CPRE          (1 << 1)
-
-#define ARM_MMU500_ACR_CACHE_LOCK      (1 << 26)
-#define ARM_MMU500_ACR_SMTNMB_TLBEN    (1 << 8)
-
-#define CB_PAR_F                       (1 << 0)
-
-#define ATSR_ACTIVE                    (1 << 0)
-
-#define RESUME_RETRY                   (0 << 0)
-#define RESUME_TERMINATE               (1 << 0)
-
-#define TTBCR2_SEP_SHIFT               15
-#define TTBCR2_SEP_UPSTREAM            (0x7 << TTBCR2_SEP_SHIFT)
-#define TTBCR2_AS                      (1 << 4)
-
-#define TTBRn_ASID_SHIFT               48
-
-#define FSR_MULTI                      (1 << 31)
-#define FSR_SS                         (1 << 30)
-#define FSR_UUT                                (1 << 8)
-#define FSR_ASF                                (1 << 7)
-#define FSR_TLBLKF                     (1 << 6)
-#define FSR_TLBMCF                     (1 << 5)
-#define FSR_EF                         (1 << 4)
-#define FSR_PF                         (1 << 3)
-#define FSR_AFF                                (1 << 2)
-#define FSR_TF                         (1 << 1)
-
-#define FSR_IGN                                (FSR_AFF | FSR_ASF | \
-                                        FSR_TLBMCF | FSR_TLBLKF)
-#define FSR_FAULT                      (FSR_MULTI | FSR_SS | FSR_UUT | \
-                                        FSR_EF | FSR_PF | FSR_TF | FSR_IGN)
-
-#define FSYNR0_WNR                     (1 << 4)
-
 #define MSI_IOVA_BASE                  0x8000000
 #define MSI_IOVA_LENGTH                        0x100000
 
@@ -338,6 +145,13 @@ struct arm_smmu_smr {
        bool                            valid;
 };
 
+struct arm_smmu_cb {
+       u64                             ttbr[2];
+       u32                             tcr[2];
+       u32                             mair[2];
+       struct arm_smmu_cfg             *cfg;
+};
+
 struct arm_smmu_master_cfg {
        struct arm_smmu_device          *smmu;
        s16                             smendx[];
@@ -380,6 +194,7 @@ struct arm_smmu_device {
        u32                             num_context_banks;
        u32                             num_s2_context_banks;
        DECLARE_BITMAP(context_map, ARM_SMMU_MAX_CBS);
+       struct arm_smmu_cb              *cbs;
        atomic_t                        irptndx;
 
        u32                             num_mapping_groups;
@@ -776,17 +591,74 @@ static irqreturn_t arm_smmu_global_fault(int irq, void *dev)
 static void arm_smmu_init_context_bank(struct arm_smmu_domain *smmu_domain,
                                       struct io_pgtable_cfg *pgtbl_cfg)
 {
-       u32 reg, reg2;
-       u64 reg64;
-       bool stage1;
        struct arm_smmu_cfg *cfg = &smmu_domain->cfg;
-       struct arm_smmu_device *smmu = smmu_domain->smmu;
+       struct arm_smmu_cb *cb = &smmu_domain->smmu->cbs[cfg->cbndx];
+       bool stage1 = cfg->cbar != CBAR_TYPE_S2_TRANS;
+
+       cb->cfg = cfg;
+
+       /* TTBCR */
+       if (stage1) {
+               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
+                       cb->tcr[0] = pgtbl_cfg->arm_v7s_cfg.tcr;
+               } else {
+                       cb->tcr[0] = pgtbl_cfg->arm_lpae_s1_cfg.tcr;
+                       cb->tcr[1] = pgtbl_cfg->arm_lpae_s1_cfg.tcr >> 32;
+                       cb->tcr[1] |= TTBCR2_SEP_UPSTREAM;
+                       if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH64)
+                               cb->tcr[1] |= TTBCR2_AS;
+               }
+       } else {
+               cb->tcr[0] = pgtbl_cfg->arm_lpae_s2_cfg.vtcr;
+       }
+
+       /* TTBRs */
+       if (stage1) {
+               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
+                       cb->ttbr[0] = pgtbl_cfg->arm_v7s_cfg.ttbr[0];
+                       cb->ttbr[1] = pgtbl_cfg->arm_v7s_cfg.ttbr[1];
+               } else {
+                       cb->ttbr[0] = pgtbl_cfg->arm_lpae_s1_cfg.ttbr[0];
+                       cb->ttbr[0] |= (u64)cfg->asid << TTBRn_ASID_SHIFT;
+                       cb->ttbr[1] = pgtbl_cfg->arm_lpae_s1_cfg.ttbr[1];
+                       cb->ttbr[1] |= (u64)cfg->asid << TTBRn_ASID_SHIFT;
+               }
+       } else {
+               cb->ttbr[0] = pgtbl_cfg->arm_lpae_s2_cfg.vttbr;
+       }
+
+       /* MAIRs (stage-1 only) */
+       if (stage1) {
+               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
+                       cb->mair[0] = pgtbl_cfg->arm_v7s_cfg.prrr;
+                       cb->mair[1] = pgtbl_cfg->arm_v7s_cfg.nmrr;
+               } else {
+                       cb->mair[0] = pgtbl_cfg->arm_lpae_s1_cfg.mair[0];
+                       cb->mair[1] = pgtbl_cfg->arm_lpae_s1_cfg.mair[1];
+               }
+       }
+}
+
+static void arm_smmu_write_context_bank(struct arm_smmu_device *smmu, int idx)
+{
+       u32 reg;
+       bool stage1;
+       struct arm_smmu_cb *cb = &smmu->cbs[idx];
+       struct arm_smmu_cfg *cfg = cb->cfg;
        void __iomem *cb_base, *gr1_base;
 
+       cb_base = ARM_SMMU_CB(smmu, idx);
+
+       /* Unassigned context banks only need disabling */
+       if (!cfg) {
+               writel_relaxed(0, cb_base + ARM_SMMU_CB_SCTLR);
+               return;
+       }
+
        gr1_base = ARM_SMMU_GR1(smmu);
        stage1 = cfg->cbar != CBAR_TYPE_S2_TRANS;
-       cb_base = ARM_SMMU_CB(smmu, cfg->cbndx);
 
+       /* CBA2R */
        if (smmu->version > ARM_SMMU_V1) {
                if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH64)
                        reg = CBA2R_RW64_64BIT;
@@ -796,7 +668,7 @@ static void arm_smmu_init_context_bank(struct arm_smmu_domain *smmu_domain,
                if (smmu->features & ARM_SMMU_FEAT_VMID16)
                        reg |= cfg->vmid << CBA2R_VMID_SHIFT;
 
-               writel_relaxed(reg, gr1_base + ARM_SMMU_GR1_CBA2R(cfg->cbndx));
+               writel_relaxed(reg, gr1_base + ARM_SMMU_GR1_CBA2R(idx));
        }
 
        /* CBAR */
@@ -815,72 +687,41 @@ static void arm_smmu_init_context_bank(struct arm_smmu_domain *smmu_domain,
                /* 8-bit VMIDs live in CBAR */
                reg |= cfg->vmid << CBAR_VMID_SHIFT;
        }
-       writel_relaxed(reg, gr1_base + ARM_SMMU_GR1_CBAR(cfg->cbndx));
+       writel_relaxed(reg, gr1_base + ARM_SMMU_GR1_CBAR(idx));
 
        /*
         * TTBCR
         * We must write this before the TTBRs, since it determines the
         * access behaviour of some fields (in particular, ASID[15:8]).
         */
-       if (stage1) {
-               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
-                       reg = pgtbl_cfg->arm_v7s_cfg.tcr;
-                       reg2 = 0;
-               } else {
-                       reg = pgtbl_cfg->arm_lpae_s1_cfg.tcr;
-                       reg2 = pgtbl_cfg->arm_lpae_s1_cfg.tcr >> 32;
-                       reg2 |= TTBCR2_SEP_UPSTREAM;
-                       if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH64)
-                               reg2 |= TTBCR2_AS;
-               }
-               if (smmu->version > ARM_SMMU_V1)
-                       writel_relaxed(reg2, cb_base + ARM_SMMU_CB_TTBCR2);
-       } else {
-               reg = pgtbl_cfg->arm_lpae_s2_cfg.vtcr;
-       }
-       writel_relaxed(reg, cb_base + ARM_SMMU_CB_TTBCR);
+       if (stage1 && smmu->version > ARM_SMMU_V1)
+               writel_relaxed(cb->tcr[1], cb_base + ARM_SMMU_CB_TTBCR2);
+       writel_relaxed(cb->tcr[0], cb_base + ARM_SMMU_CB_TTBCR);
 
        /* TTBRs */
-       if (stage1) {
-               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
-                       reg = pgtbl_cfg->arm_v7s_cfg.ttbr[0];
-                       writel_relaxed(reg, cb_base + ARM_SMMU_CB_TTBR0);
-                       reg = pgtbl_cfg->arm_v7s_cfg.ttbr[1];
-                       writel_relaxed(reg, cb_base + ARM_SMMU_CB_TTBR1);
-                       writel_relaxed(cfg->asid, cb_base + ARM_SMMU_CB_CONTEXTIDR);
-               } else {
-                       reg64 = pgtbl_cfg->arm_lpae_s1_cfg.ttbr[0];
-                       reg64 |= (u64)cfg->asid << TTBRn_ASID_SHIFT;
-                       writeq_relaxed(reg64, cb_base + ARM_SMMU_CB_TTBR0);
-                       reg64 = pgtbl_cfg->arm_lpae_s1_cfg.ttbr[1];
-                       reg64 |= (u64)cfg->asid << TTBRn_ASID_SHIFT;
-                       writeq_relaxed(reg64, cb_base + ARM_SMMU_CB_TTBR1);
-               }
+       if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
+               writel_relaxed(cfg->asid, cb_base + ARM_SMMU_CB_CONTEXTIDR);
+               writel_relaxed(cb->ttbr[0], cb_base + ARM_SMMU_CB_TTBR0);
+               writel_relaxed(cb->ttbr[1], cb_base + ARM_SMMU_CB_TTBR1);
        } else {
-               reg64 = pgtbl_cfg->arm_lpae_s2_cfg.vttbr;
-               writeq_relaxed(reg64, cb_base + ARM_SMMU_CB_TTBR0);
+               writeq_relaxed(cb->ttbr[0], cb_base + ARM_SMMU_CB_TTBR0);
+               if (stage1)
+                       writeq_relaxed(cb->ttbr[1], cb_base + ARM_SMMU_CB_TTBR1);
        }
 
        /* MAIRs (stage-1 only) */
        if (stage1) {
-               if (cfg->fmt == ARM_SMMU_CTX_FMT_AARCH32_S) {
-                       reg = pgtbl_cfg->arm_v7s_cfg.prrr;
-                       reg2 = pgtbl_cfg->arm_v7s_cfg.nmrr;
-               } else {
-                       reg = pgtbl_cfg->arm_lpae_s1_cfg.mair[0];
-                       reg2 = pgtbl_cfg->arm_lpae_s1_cfg.mair[1];
-               }
-               writel_relaxed(reg, cb_base + ARM_SMMU_CB_S1_MAIR0);
-               writel_relaxed(reg2, cb_base + ARM_SMMU_CB_S1_MAIR1);
+               writel_relaxed(cb->mair[0], cb_base + ARM_SMMU_CB_S1_MAIR0);
+               writel_relaxed(cb->mair[1], cb_base + ARM_SMMU_CB_S1_MAIR1);
        }
 
        /* SCTLR */
        reg = SCTLR_CFIE | SCTLR_CFRE | SCTLR_AFE | SCTLR_TRE | SCTLR_M;
        if (stage1)
                reg |= SCTLR_S1_ASIDPNE;
-#ifdef __BIG_ENDIAN
-       reg |= SCTLR_E;
-#endif
+       if (IS_ENABLED(CONFIG_CPU_BIG_ENDIAN))
+               reg |= SCTLR_E;
+
        writel_relaxed(reg, cb_base + ARM_SMMU_CB_SCTLR);
 }
 
@@ -1043,6 +884,7 @@ static int arm_smmu_init_domain_context(struct iommu_domain *domain,
 
        /* Initialise the context bank with our page table cfg */
        arm_smmu_init_context_bank(smmu_domain, &pgtbl_cfg);
+       arm_smmu_write_context_bank(smmu, cfg->cbndx);
 
        /*
         * Request context fault interrupt. Do this last to avoid the
@@ -1075,7 +917,6 @@ static void arm_smmu_destroy_domain_context(struct iommu_domain *domain)
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
        struct arm_smmu_device *smmu = smmu_domain->smmu;
        struct arm_smmu_cfg *cfg = &smmu_domain->cfg;
-       void __iomem *cb_base;
        int irq;
 
        if (!smmu || domain->type == IOMMU_DOMAIN_IDENTITY)
@@ -1085,8 +926,8 @@ static void arm_smmu_destroy_domain_context(struct iommu_domain *domain)
         * Disable the context bank and free the page tables before freeing
         * it.
         */
-       cb_base = ARM_SMMU_CB(smmu, cfg->cbndx);
-       writel_relaxed(0, cb_base + ARM_SMMU_CB_SCTLR);
+       smmu->cbs[cfg->cbndx].cfg = NULL;
+       arm_smmu_write_context_bank(smmu, cfg->cbndx);
 
        if (cfg->irptndx != INVALID_IRPTNDX) {
                irq = smmu->irqs[smmu->num_global_irqs + cfg->irptndx];
@@ -1736,7 +1577,6 @@ static struct iommu_ops arm_smmu_ops = {
 static void arm_smmu_device_reset(struct arm_smmu_device *smmu)
 {
        void __iomem *gr0_base = ARM_SMMU_GR0(smmu);
-       void __iomem *cb_base;
        int i;
        u32 reg, major;
 
@@ -1772,8 +1612,9 @@ static void arm_smmu_device_reset(struct arm_smmu_device *smmu)
 
        /* Make sure all context banks are disabled and clear CB_FSR  */
        for (i = 0; i < smmu->num_context_banks; ++i) {
-               cb_base = ARM_SMMU_CB(smmu, i);
-               writel_relaxed(0, cb_base + ARM_SMMU_CB_SCTLR);
+               void __iomem *cb_base = ARM_SMMU_CB(smmu, i);
+
+               arm_smmu_write_context_bank(smmu, i);
                writel_relaxed(FSR_FAULT, cb_base + ARM_SMMU_CB_FSR);
                /*
                 * Disable MMU-500's not-particularly-beneficial next-page
@@ -1979,6 +1820,10 @@ static int arm_smmu_device_cfg_probe(struct arm_smmu_device *smmu)
                smmu->cavium_id_base -= smmu->num_context_banks;
                dev_notice(smmu->dev, "\tenabling workaround for Cavium erratum 27704\n");
        }
+       smmu->cbs = devm_kcalloc(smmu->dev, smmu->num_context_banks,
+                                sizeof(*smmu->cbs), GFP_KERNEL);
+       if (!smmu->cbs)
+               return -ENOMEM;
 
        /* ID2 */
        id = readl_relaxed(gr0_base + ARM_SMMU_GR0_ID2);
@@ -2336,13 +2181,30 @@ static int arm_smmu_device_remove(struct platform_device *pdev)
        return 0;
 }
 
+static void arm_smmu_device_shutdown(struct platform_device *pdev)
+{
+       arm_smmu_device_remove(pdev);
+}
+
+static int __maybe_unused arm_smmu_pm_resume(struct device *dev)
+{
+       struct arm_smmu_device *smmu = dev_get_drvdata(dev);
+
+       arm_smmu_device_reset(smmu);
+       return 0;
+}
+
+static SIMPLE_DEV_PM_OPS(arm_smmu_pm_ops, NULL, arm_smmu_pm_resume);
+
 static struct platform_driver arm_smmu_driver = {
        .driver = {
                .name           = "arm-smmu",
                .of_match_table = of_match_ptr(arm_smmu_of_match),
+               .pm             = &arm_smmu_pm_ops,
        },
        .probe  = arm_smmu_device_probe,
        .remove = arm_smmu_device_remove,
+       .shutdown = arm_smmu_device_shutdown,
 };
 module_platform_driver(arm_smmu_driver);