package com.kidgrow.authclient.store;
import com.kidgrow.authclient.properties.SecurityProperties;
import com.kidgrow.common.constant.SecurityConstants;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.security.oauth2.provider.token.store.redis.JdkSerializationStrategy;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStore;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStoreSerializationStrategy;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import java.lang.reflect.Method;
import java.util.*;
/**
* 石家庄喜高科技有限责任公司 版权所有 © Copyright 2020
* @Description: 优化自Spring Security的RedisTokenStore
* 1. 支持redis所有集群模式包括cluster模式
* 2. 使用pipeline减少连接次数,提升性能
* 3. 自动续签token(可配置是否开启)
* @Project:
* @CreateDate: Created in 2020/2/13 11:29
* @Author: liuke
*/
public class CustomRedisTokenStore implements TokenStore {
// Redis key prefix: access:{tokenValue} -> serialized access token.
private static final String ACCESS = "access:";
// Redis key prefix: auth_to_access:{authenticationKey} -> serialized access token.
private static final String AUTH_TO_ACCESS = "auth_to_access:";
// Redis key prefix: refresh_auth:{refreshTokenValue} -> serialized authentication.
private static final String REFRESH_AUTH = "refresh_auth:";
// Redis key prefix: access_to_refresh:{tokenValue}; deleted when an access token is removed.
private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
// Redis key prefixes for refresh-token data (refresh-token writes are disabled in storeAccessToken).
private static final String REFRESH = "refresh:";
private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
// True when Spring Data Redis 2.0+ is on the classpath, detected by the presence of the
// RedisStandaloneConfiguration class. When true, RedisConnection#set(byte[], byte[]) is
// called reflectively through redisConnectionSet_2_0.
private static final boolean springDataRedis_2_0 = ClassUtils.isPresent(
"org.springframework.data.redis.connection.RedisStandaloneConfiguration",
RedisTokenStore.class.getClassLoader());
// Source of Redis connections; each operation in this class opens one and closes it when done.
private final RedisConnectionFactory connectionFactory;
// Derives the auth_to_access key from an OAuth2Authentication.
private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
// Converts keys and values to and from Redis byte arrays (JDK serialization by default).
private RedisTokenStoreSerializationStrategy serializationStrategy = new JdkSerializationStrategy();
// Optional namespace prepended to every key by serializeKey (empty by default).
private String prefix = "";
// Reflective handle to RedisConnection#set(byte[], byte[]); resolved only when springDataRedis_2_0 is true.
private Method redisConnectionSet_2_0;
/**
* Authentication settings; supplies the token auto-renewal configuration.
*/
private SecurityProperties securityProperties;
/**
 * Creates a token store backed by the given Redis connection factory.
 *
 * @param connectionFactory  factory used to obtain a connection for every Redis operation
 * @param securityProperties authentication settings, including the token-renewal configuration
 */
public CustomRedisTokenStore(RedisConnectionFactory connectionFactory, SecurityProperties securityProperties) {
    this.securityProperties = securityProperties;
    this.connectionFactory = connectionFactory;
    // The reflective set(byte[], byte[]) handle is only needed on Spring Data Redis 2.0+.
    if (!springDataRedis_2_0) {
        return;
    }
    this.loadRedisConnectionMethods_2_0();
}
/**
 * Replaces the generator that derives the auth_to_access key from an authentication.
 *
 * @param authenticationKeyGenerator the key generator to use from now on
 */
public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
    AuthenticationKeyGenerator generator = authenticationKeyGenerator;
    this.authenticationKeyGenerator = generator;
}
/**
 * Replaces the strategy used to convert keys and values to and from Redis byte arrays.
 *
 * @param serializationStrategy the serialization strategy to use from now on
 */
public void setSerializationStrategy(RedisTokenStoreSerializationStrategy serializationStrategy) {
    RedisTokenStoreSerializationStrategy strategy = serializationStrategy;
    this.serializationStrategy = strategy;
}
/**
 * Sets the namespace that serializeKey prepends to every Redis key.
 *
 * @param prefix key prefix (stored as given)
 */
public void setPrefix(String prefix) {
    String keyPrefix = prefix;
    this.prefix = keyPrefix;
}
/**
 * Looks up RedisConnection#set(byte[], byte[]) reflectively and caches it in redisConnectionSet_2_0.
 */
private void loadRedisConnectionMethods_2_0() {
    Method setMethod = ReflectionUtils.findMethod(RedisConnection.class, "set", byte[].class, byte[].class);
    this.redisConnectionSet_2_0 = setMethod;
}
/**
 * Obtains a new connection from the configured factory; callers are responsible for closing it.
 */
private RedisConnection getConnection() {
    RedisConnection connection = this.connectionFactory.getConnection();
    return connection;
}
/**
 * Serializes an arbitrary value (token, authentication, ...) with the configured strategy.
 */
private byte[] serialize(Object object) {
    byte[] encoded = this.serializationStrategy.serialize(object);
    return encoded;
}
/**
 * Prepends the configured prefix to a key and serializes it as a string.
 */
private byte[] serializeKey(String object) {
    String fullKey = prefix + object;
    // fullKey is a String, so this resolves to the serialize(String) overload.
    return serialize(fullKey);
}
/**
 * Decodes a stored access token using the configured serialization strategy.
 */
private OAuth2AccessToken deserializeAccessToken(byte[] bytes) {
    OAuth2AccessToken decoded = this.serializationStrategy.deserialize(bytes, OAuth2AccessToken.class);
    return decoded;
}
/**
 * Decodes a stored OAuth2 authentication using the configured serialization strategy.
 */
private OAuth2Authentication deserializeAuthentication(byte[] bytes) {
    OAuth2Authentication decoded = this.serializationStrategy.deserialize(bytes, OAuth2Authentication.class);
    return decoded;
}
/**
 * Decodes a stored refresh token using the configured serialization strategy.
 */
private OAuth2RefreshToken deserializeRefreshToken(byte[] bytes) {
    OAuth2RefreshToken decoded = this.serializationStrategy.deserialize(bytes, OAuth2RefreshToken.class);
    return decoded;
}
/**
 * Decodes cached client details using the configured serialization strategy.
 */
private ClientDetails deserializeClientDetails(byte[] bytes) {
    ClientDetails decoded = this.serializationStrategy.deserialize(bytes, ClientDetails.class);
    return decoded;
}
/**
 * Serializes a plain string with the configured strategy.
 */
private byte[] serialize(String string) {
    byte[] encoded = this.serializationStrategy.serialize(string);
    return encoded;
}
/**
 * Decodes a plain string with the configured strategy.
 */
private String deserializeString(byte[] bytes) {
    String decoded = this.serializationStrategy.deserializeString(bytes);
    return decoded;
}
/**
 * Looks up the access token previously issued for an authentication via the auth_to_access index.
 *
 * @param authentication the authentication to look up
 * @return the stored access token, or null if none is indexed for this authentication
 */
@Override
public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
    String authenticationKey = authenticationKeyGenerator.extractKey(authentication);
    byte[] indexKey = serializeKey(AUTH_TO_ACCESS + authenticationKey);
    byte[] rawToken;
    RedisConnection connection = getConnection();
    try {
        rawToken = connection.get(indexKey);
    } finally {
        connection.close();
    }
    OAuth2AccessToken existing = deserializeAccessToken(rawToken);
    if (existing == null) {
        return null;
    }
    OAuth2Authentication stored = readAuthentication(existing.getValue());
    boolean outOfSync = stored == null
            || !authenticationKey.equals(authenticationKeyGenerator.extractKey(stored));
    if (outOfSync) {
        // Keep the stores consistent (the same user may be represented by this authentication
        // while its details have changed).
        // NOTE(review): the non-renewal storeAccessToken returns early when the access: key already
        // exists, so this re-store may be a no-op in practice.
        storeAccessToken(existing, authentication);
    }
    return existing;
}
/**
 * Reads the authentication bound to an access token and, when auto-renewal is enabled, slides
 * the token's expiration forward.
 *
 * <p>Renewal happens when all of the following hold:
 * <ul>
 *   <li>renewal is switched on;</li>
 *   <li>the token has an expiration;</li>
 *   <li>its client passes the white/black list check;</li>
 *   <li>the remaining lifetime, as a fraction of the client's full validity, is at or below the
 *       configured time ratio.</li>
 * </ul>
 * The renewed token is written back to Redis.
 *
 * @param token the access token whose authentication is requested
 * @return the stored authentication, or null if none exists
 */
@Override
public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
    OAuth2Authentication auth2Authentication = readAuthentication(token.getValue());
    // Whether token auto-renewal is enabled; Boolean.TRUE.equals tolerates an unset (null) flag.
    boolean isRenew = Boolean.TRUE.equals(securityProperties.getAuth().getRenew().getEnable());
    // A token without an expiration never expires, so there is nothing to renew. Without this guard,
    // DefaultOAuth2AccessToken#getExpiresIn() reports 0 for such a token, the ratio check below always
    // passes, and the token would silently be given an expiration.
    if (!isRenew || auth2Authentication == null || token.getExpiration() == null) {
        return auth2Authentication;
    }
    OAuth2Request clientAuth = auth2Authentication.getOAuth2Request();
    String clientId = clientAuth.getClientId();
    // Check whether this client is eligible for auto-renewal.
    if (!checkRenewClientId(clientId)) {
        return auth2Authentication;
    }
    // The client's full token validity, in seconds.
    int validitySeconds = getAccessTokenValiditySeconds(clientId);
    if (validitySeconds <= 0) {
        return auth2Authentication;
    }
    double expiresRatio = token.getExpiresIn() / (double) validitySeconds;
    // Renew once the remaining lifetime has dropped to the configured fraction of the full validity.
    if (expiresRatio <= securityProperties.getAuth().getRenew().getTimeRatio()) {
        // Update the stored token in place when possible; otherwise renew a copy instead of failing
        // with a ClassCastException on other OAuth2AccessToken implementations.
        DefaultOAuth2AccessToken renewed = (token instanceof DefaultOAuth2AccessToken)
                ? (DefaultOAuth2AccessToken) token
                : new DefaultOAuth2AccessToken(token);
        renewed.setExpiration(new Date(System.currentTimeMillis() + (validitySeconds * 1000L)));
        storeAccessToken(renewed, auth2Authentication, true);
    }
    return auth2Authentication;
}
/**
 * Applies the auto-renewal white list and black list to a client.
 *
 * <p>The white list (includeClientIds) takes precedence: when it is non-empty, only listed clients
 * are renewed. Otherwise a non-empty black list (exclusiveClientIds) excludes the listed clients.
 * Missing (null) or empty lists impose no restriction.
 *
 * @param clientId the client (application) id
 * @return true if tokens of this client may be auto-renewed
 */
private boolean checkRenewClientId(String clientId) {
    // White list
    List<?> includeClientIds = securityProperties.getAuth().getRenew().getIncludeClientIds();
    if (includeClientIds != null && !includeClientIds.isEmpty()) {
        return includeClientIds.contains(clientId);
    }
    // Black list (only consulted when no white list is configured)
    List<?> exclusiveClientIds = securityProperties.getAuth().getRenew().getExclusiveClientIds();
    if (exclusiveClientIds != null && !exclusiveClientIds.isEmpty()) {
        return !exclusiveClientIds.contains(clientId);
    }
    return true;
}
/**
 * Returns the full access-token validity for a client.
 *
 * <p>The value is read from the client details cached in Redis. If the client is not cached,
 * cannot be decoded, or has no validity configured, the default
 * SecurityConstants.ACCESS_TOKEN_VALIDITY_SECONDS is returned.
 *
 * @param clientId the client (application) id
 * @return validity in seconds
 */
private int getAccessTokenValiditySeconds(String clientId) {
    byte[] cacheKey = serializeKey(SecurityConstants.CACHE_CLIENT_KEY + ":" + clientId);
    byte[] bytes;
    RedisConnection conn = getConnection();
    try {
        bytes = conn.get(cacheKey);
    } finally {
        conn.close();
    }
    // The serialization strategy may yield null (e.g. for an empty payload), so guard before
    // dereferencing instead of risking a NullPointerException.
    ClientDetails clientDetails = (bytes == null) ? null : deserializeClientDetails(bytes);
    if (clientDetails != null) {
        Integer validitySeconds = clientDetails.getAccessTokenValiditySeconds();
        if (validitySeconds != null) {
            return validitySeconds;
        }
    }
    // Fall back to the default validity.
    return SecurityConstants.ACCESS_TOKEN_VALIDITY_SECONDS;
}
/**
 * Loads the authentication stored under the REDIS_TOKEN_AUTH key for an access-token value.
 *
 * @param token the access-token value
 * @return the stored authentication, or whatever the serialization strategy yields for a missing key
 */
@Override
public OAuth2Authentication readAuthentication(String token) {
    RedisConnection connection = getConnection();
    byte[] payload;
    try {
        byte[] lookupKey = serializeKey(SecurityConstants.REDIS_TOKEN_AUTH + token);
        payload = connection.get(lookupKey);
    } finally {
        connection.close();
    }
    return deserializeAuthentication(payload);
}
/**
 * Loads the authentication associated with a refresh token by delegating to the value-based overload.
 */
@Override
public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
    String refreshValue = token.getValue();
    return readAuthenticationForRefreshToken(refreshValue);
}
/**
 * Loads the authentication stored under refresh_auth:{token}.
 *
 * @param token the refresh-token value
 * @return the deserialized authentication
 */
public OAuth2Authentication readAuthenticationForRefreshToken(String token) {
    RedisConnection connection = getConnection();
    byte[] payload;
    try {
        byte[] lookupKey = serializeKey(REFRESH_AUTH + token);
        payload = connection.get(lookupKey);
    } finally {
        connection.close();
    }
    return deserializeAuthentication(payload);
}
/**
 * Stores a newly issued access token. This is not a renewal, so a token that already exists
 * in Redis is left untouched.
 */
@Override
public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
    final boolean renewal = false;
    storeAccessToken(token, authentication, renewal);
}
/**
 * Stores an access token and its indexes in a single pipelined round trip.
 *
 * <p>Writes the following:
 * <ul>
 *   <li>access:{token} -> token;</li>
 *   <li>REDIS_TOKEN_AUTH{token} -> authentication;</li>
 *   <li>auth_to_access:{authenticationKey} -> token;</li>
 *   <li>an append of the token to the per-client list, and, for user (non client-only)
 *       authentications, to the per-user list.</li>
 * </ul>
 *
 * @param token          the access token to store
 * @param authentication the authentication the token was issued for
 * @param isRenew        true when re-storing an existing token with an extended expiration (the old
 *                       list entries are replaced); when false and the token already exists, nothing
 *                       is written
 */
private void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication, boolean isRenew) {
    byte[] serializedAccessToken = serialize(token);
    byte[] serializedAuth = serialize(authentication);
    byte[] accessKey = serializeKey(ACCESS + token.getValue());
    byte[] authKey = serializeKey(SecurityConstants.REDIS_TOKEN_AUTH + token.getValue());
    byte[] authToAccessKey = serializeKey(AUTH_TO_ACCESS + authenticationKeyGenerator.extractKey(authentication));
    byte[] approvalKey = serializeKey(SecurityConstants.REDIS_UNAME_TO_ACCESS + getApprovalKey(authentication));
    byte[] clientId = serializeKey(SecurityConstants.REDIS_CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId());
    RedisConnection conn = getConnection();
    try {
        byte[] oldAccessToken = conn.get(accessKey);
        // If the token already exists and this is not a renewal, leave it as is.
        if (!isRenew && oldAccessToken != null) {
            return;
        }
        conn.openPipeline();
        if (springDataRedis_2_0) {
            // ReflectionUtils.invokeMethod rethrows the underlying runtime exception (e.g. a Redis
            // DataAccessException) rather than burying it inside an InvocationTargetException.
            ReflectionUtils.invokeMethod(this.redisConnectionSet_2_0, conn, accessKey, serializedAccessToken);
            ReflectionUtils.invokeMethod(this.redisConnectionSet_2_0, conn, authKey, serializedAuth);
            ReflectionUtils.invokeMethod(this.redisConnectionSet_2_0, conn, authToAccessKey, serializedAccessToken);
        } else {
            conn.set(accessKey, serializedAccessToken);
            conn.set(authKey, serializedAuth);
            conn.set(authToAccessKey, serializedAccessToken);
        }
        // When renewing, first remove the old serialized token from the per-user/per-client lists.
        if (oldAccessToken != null) {
            if (!authentication.isClientOnly()) {
                conn.lRem(approvalKey, 1, oldAccessToken);
            }
            conn.lRem(clientId, 1, oldAccessToken);
        }
        if (!authentication.isClientOnly()) {
            conn.rPush(approvalKey, serializedAccessToken);
        }
        conn.rPush(clientId, serializedAccessToken);
        // NOTE: by design of this store, no TTL is set on any of the keys above (expiry is tracked
        // only through the token's own expiration), and refresh-token mappings (refresh_to_access /
        // access_to_refresh) are not written because token refresh is disabled.
        conn.closePipeline();
    } finally {
        conn.close();
    }
}
/**
 * Builds the per-user list key suffix ("clientId:userName") for an authentication.
 * Client-only authentications use an empty user name.
 */
private static String getApprovalKey(OAuth2Authentication authentication) {
    String userName = "";
    if (authentication.getUserAuthentication() != null) {
        userName = authentication.getUserAuthentication().getName();
    }
    String owningClient = authentication.getOAuth2Request().getClientId();
    return getApprovalKey(owningClient, userName);
}
/**
 * Joins client id and user name as "clientId:userName".
 * A null user name yields just the client id.
 */
private static String getApprovalKey(String clientId, String userName) {
    if (userName == null) {
        return String.valueOf((Object) clientId);
    }
    return clientId + ":" + userName;
}
/**
 * Removes an access token by delegating to the value-based overload.
 */
@Override
public void removeAccessToken(OAuth2AccessToken accessToken) {
    String tokenValue = accessToken.getValue();
    removeAccessToken(tokenValue);
}
/**
 * Loads the access token stored under access:{tokenValue}.
 *
 * @param tokenValue the access-token value
 * @return the deserialized token
 */
@Override
public OAuth2AccessToken readAccessToken(String tokenValue) {
    byte[] tokenKey = serializeKey(ACCESS + tokenValue);
    RedisConnection connection = getConnection();
    byte[] stored;
    try {
        stored = connection.get(tokenKey);
    } finally {
        connection.close();
    }
    return deserializeAccessToken(stored);
}
public void removeAccessToken(String tokenValue) {
byte[] accessKey = serializeKey(ACCESS + tokenValue);
byte[] authKey = serializeKey(SecurityConstants.REDIS_TOKEN_AUTH + tokenValue);
byte[] accessToRefreshKey = serializeKey(ACCESS_TO_REFRESH + tokenValue);
RedisConnection conn = getConnection();
try {
conn.openPipeline();
conn.get(accessKey);
conn.get(authKey);
conn.del(accessKey);
conn.del(accessToRefreshKey);
// Don't remove the refresh token - it's up to the caller to do that
conn.del(authKey);
List