/*
 * Copyright 2016 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.testsuite.model;

import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import org.jboss.logging.Logger;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.services.managers.DBLockManager;
import org.keycloak.models.dblock.DBLockProvider;
import org.keycloak.models.dblock.DBLockProviderFactory;
import org.keycloak.models.utils.KeycloakModelUtils;

/**
 * @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
 */
public class DBLockTest extends AbstractModelTest {

    private static final Logger log = Logger.getLogger(DBLockTest.class);

    private static final int SLEEP_TIME_MILLIS = 10;
    private static final int THREADS_COUNT = 20;
    private static final int ITERATIONS_PER_THREAD = 2;

    private static final int LOCK_TIMEOUT_MILLIS = 240000; // Rather bigger to handle slow DB connections in testing env
    private static final int LOCK_RECHECK_MILLIS = 10;

    @Before
    @Override
    public void before() throws Exception {
        super.before();

        // Set timeouts for testing
        DBLockManager lockManager = new DBLockManager();
        DBLockProviderFactory lockFactory = lockManager.getDBLockFactory(session);
        lockFactory.setTimeouts(LOCK_RECHECK_MILLIS, LOCK_TIMEOUT_MILLIS);

        // Drop lock table, just to simulate racing threads for create lock table and insert lock record into it.
        lockManager.getDBLock(session).destroyLockInfo();

        commit();
    }

    // @Test // TODO: Running -Dtest=DBLockTest,UserModelTest might cause issues sometimes. Reenable this once DB lock is refactored.
    public void testLockConcurrently() throws Exception {
        long startupTime = System.currentTimeMillis();

        final Semaphore semaphore = new Semaphore();
        final KeycloakSessionFactory sessionFactory = realmManager.getSession().getKeycloakSessionFactory();

        List<Thread> threads = new LinkedList<>();
        for (int i=0 ; i<THREADS_COUNT ; i++) {
            Thread thread = new Thread() {

                @Override
                public void run() {
                    for (int i=0 ; i<ITERATIONS_PER_THREAD ; i++) {
                        try {
                            KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {

                                @Override
                                public void run(KeycloakSession session) {
                                    lock(session, semaphore);
                                }

                            });
                        } catch (RuntimeException e) {
                            semaphore.setException(e);
                            throw e;
                        }
                    }
                }

            };
            threads.add(thread);
        }

        for (Thread thread : threads) {
            thread.start();
        }
        for (Thread thread : threads) {
            thread.join();
        }

        long took = (System.currentTimeMillis() - startupTime);
        log.infof("DBLockTest executed in %d ms with total counter %d. THREADS_COUNT=%d, ITERATIONS_PER_THREAD=%d", took, semaphore.getTotal(), THREADS_COUNT, ITERATIONS_PER_THREAD);
        Assert.assertEquals(semaphore.getTotal(), THREADS_COUNT * ITERATIONS_PER_THREAD);
        Assert.assertNull(semaphore.getException());
    }

    private void lock(KeycloakSession session, Semaphore semaphore) {
        DBLockProvider dbLock = new DBLockManager().getDBLock(session);
        dbLock.waitForLock();
        try {
            semaphore.increase();
            Thread.sleep(SLEEP_TIME_MILLIS);
            semaphore.decrease();
        } catch (InterruptedException ie) {
            throw new RuntimeException(ie);
        } finally {
            dbLock.releaseLock();
        }
    }


    // Ensure just one thread is allowed to run at the same time
    private class Semaphore {

        private AtomicInteger counter = new AtomicInteger(0);
        private AtomicInteger totalIncreases = new AtomicInteger(0);

        private volatile Exception exception = null;

        private void increase() {
            int current = counter.incrementAndGet();
            if (current != 1) {
                IllegalStateException ex = new IllegalStateException("Counter has illegal value: " + current);
                setException(ex);
                throw ex;
            }
            totalIncreases.incrementAndGet();
        }

        private void decrease() {
            int current = counter.decrementAndGet();
            if (current != 0) {
                IllegalStateException ex = new IllegalStateException("Counter has illegal value: " + current);
                setException(ex);
                throw ex;
            }
        }

        private synchronized void setException(Exception exception) {
            if (this.exception == null) {
                this.exception = exception;
            }
        }

        private synchronized Exception getException() {
            return exception;
        }

        private int getTotal() {
            return totalIncreases.get();
        }
    }




}
